mirror of https://github.com/tracel-ai/burn.git
Remove topological sort of nodes (#589)
ONNX nodes should already be topologically sorted per the ONNX spec, so we are removing the explicit topological sort; instead, we now check that the nodes are topologically sorted.
This commit is contained in:
parent
ce8a175aa4
commit
5cc32cc8cb
|
@ -55,7 +55,6 @@ strum_macros = "0.24"
|
||||||
syn = {version = "2.0", features = ["full", "extra-traits"]}
|
syn = {version = "2.0", features = ["full", "extra-traits"]}
|
||||||
tempfile = "3.6.0"
|
tempfile = "3.6.0"
|
||||||
thiserror = "1.0.40"
|
thiserror = "1.0.40"
|
||||||
topological-sort = "0.2.2"
|
|
||||||
|
|
||||||
# WGPU stuff
|
# WGPU stuff
|
||||||
futures-intrusive = "0.5"
|
futures-intrusive = "0.5"
|
||||||
|
|
|
@ -35,7 +35,6 @@ serde_json = {workspace = true, features = ["std"]}
|
||||||
strum = {workspace = true}
|
strum = {workspace = true}
|
||||||
strum_macros = {workspace = true}
|
strum_macros = {workspace = true}
|
||||||
syn = {workspace = true, features = ["parsing"]}
|
syn = {workspace = true, features = ["parsing"]}
|
||||||
topological-sort = {workspace = true}
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
protobuf-codegen = {workspace = true}
|
protobuf-codegen = {workspace = true}
|
||||||
|
|
|
@ -455,8 +455,12 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
||||||
});
|
});
|
||||||
|
|
||||||
output_names.iter().for_each(|output| {
|
output_names.iter().for_each(|output| {
|
||||||
self.graph_output_types
|
self.graph_output_types.push(
|
||||||
.push(outputs.get(output).unwrap().clone());
|
outputs
|
||||||
|
.get(output)
|
||||||
|
.unwrap_or_else(|| panic!("Output type is not found for {output}"))
|
||||||
|
.clone(),
|
||||||
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::HashMap,
|
||||||
fs::File,
|
fs::File,
|
||||||
path::Path,
|
path::Path,
|
||||||
str::{from_utf8, FromStr},
|
str::{from_utf8, FromStr},
|
||||||
|
@ -19,7 +19,6 @@ use super::{coalesce::coalesce, ir::StateType};
|
||||||
|
|
||||||
use bytemuck::cast_slice;
|
use bytemuck::cast_slice;
|
||||||
use protobuf::{Enum, Message};
|
use protobuf::{Enum, Message};
|
||||||
use topological_sort::TopologicalSort;
|
|
||||||
|
|
||||||
/// Error type for parsing ONNX model
|
/// Error type for parsing ONNX model
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -28,6 +27,20 @@ pub enum ParseError {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Open an onnx file and convert it to a Graph (intermediate representation)
|
/// Open an onnx file and convert it to a Graph (intermediate representation)
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `onnx_path` - Path to the onnx file
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `ONNXGraph` - The graph representation of the onnx file
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// * If the file cannot be opened
|
||||||
|
/// * If the file cannot be parsed
|
||||||
|
/// * If the nodes are not topologically sorted
|
||||||
pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
||||||
log::info!("Parsing ONNX file: {}", onnx_path.display());
|
log::info!("Parsing ONNX file: {}", onnx_path.display());
|
||||||
|
|
||||||
|
@ -52,29 +65,35 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
||||||
nodes.push(convert_node_proto(onnx_node));
|
nodes.push(convert_node_proto(onnx_node));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ONNX nodes must be topologically sorted per spec:
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
|
||||||
|
assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
|
||||||
|
|
||||||
// Move inputs to initializers
|
// Move inputs to initializers
|
||||||
move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
|
move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
|
||||||
|
|
||||||
// Get the topological sort of the nodes and the top nodes
|
|
||||||
top_sort_nodes(&mut nodes);
|
|
||||||
|
|
||||||
// Collect inputs, outputs and initializers
|
|
||||||
let check_if_initializer: HashSet<String> = onnx_model
|
|
||||||
.graph
|
|
||||||
.initializer
|
|
||||||
.iter()
|
|
||||||
.map(|x| x.name.clone())
|
|
||||||
.collect();
|
|
||||||
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer);
|
|
||||||
|
|
||||||
let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
|
|
||||||
let states = collect_states(onnx_model);
|
|
||||||
|
|
||||||
// Coalesce and transform nodes
|
// Coalesce and transform nodes
|
||||||
coalesce(&mut nodes);
|
coalesce(&mut nodes);
|
||||||
|
|
||||||
// Rename nodes and inputs, save the mapping for later
|
// Rename nodes and inputs, save the mapping for later
|
||||||
let old_node_names = rename_nodes(&mut nodes);
|
let old_node_names = rename_nodes(&mut nodes);
|
||||||
|
|
||||||
|
// This function collects the inputs of an ONNX model and returns them as a vector of Arguments.
|
||||||
|
let mut inputs = onnx_model
|
||||||
|
.graph
|
||||||
|
.input
|
||||||
|
.iter()
|
||||||
|
.map(|x| Argument::try_from(x.clone()).unwrap())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Map each output in the model's graph to an Argument and collect them into a vector.
|
||||||
|
let mut outputs = onnx_model
|
||||||
|
.graph
|
||||||
|
.output
|
||||||
|
.iter()
|
||||||
|
.map(|x| Argument::try_from(x.clone()).unwrap())
|
||||||
|
.collect();
|
||||||
|
|
||||||
let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
|
let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
|
||||||
|
|
||||||
// Infer shapes and update the inputs and outputs
|
// Infer shapes and update the inputs and outputs
|
||||||
|
@ -86,71 +105,11 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
||||||
nodes,
|
nodes,
|
||||||
inputs,
|
inputs,
|
||||||
outputs,
|
outputs,
|
||||||
states,
|
|
||||||
old_node_names,
|
old_node_names,
|
||||||
old_input_names,
|
old_input_names,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Collect initializers
|
|
||||||
fn collect_states(onnx_model: ModelProto) -> Vec<State> {
|
|
||||||
let mut initializers = Vec::new();
|
|
||||||
|
|
||||||
for initializer in onnx_model.graph.initializer.iter() {
|
|
||||||
let tensor_proto = initializer.clone();
|
|
||||||
let name = tensor_proto.name.clone();
|
|
||||||
let tensor: Tensor = tensor_proto.try_into().unwrap();
|
|
||||||
let ty = StateType::Tensor(tensor);
|
|
||||||
let arg = State { name, ty };
|
|
||||||
|
|
||||||
initializers.push(arg);
|
|
||||||
}
|
|
||||||
initializers
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Collect outputs
|
|
||||||
fn collect_outputs(
|
|
||||||
onnx_model: &ModelProto,
|
|
||||||
check_if_initializer: HashSet<String>,
|
|
||||||
) -> Vec<Argument> {
|
|
||||||
let outputs: Vec<Argument> = onnx_model
|
|
||||||
.graph
|
|
||||||
.output
|
|
||||||
.iter()
|
|
||||||
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
|
|
||||||
.map(|i| Argument::try_from(i.clone()).unwrap())
|
|
||||||
.collect();
|
|
||||||
outputs
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Collect inputs
|
|
||||||
fn collect_inputs(
|
|
||||||
onnx_model: &ModelProto,
|
|
||||||
check_if_initializer: &HashSet<String>,
|
|
||||||
) -> Vec<Argument> {
|
|
||||||
// Get the unique inputs
|
|
||||||
let inputs: Vec<Argument> = onnx_model
|
|
||||||
.graph
|
|
||||||
.input
|
|
||||||
.iter()
|
|
||||||
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
|
|
||||||
// .filter(|x| top_nodes.contains(&x.name))
|
|
||||||
.map(|x| Argument::try_from(x.clone()).unwrap())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Convert to a vector and return
|
|
||||||
inputs.into_iter().collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sort the nodes in topological order
|
|
||||||
fn top_sort_nodes(nodes: &mut Vec<Node>) {
|
|
||||||
let mut ts = topsort(nodes);
|
|
||||||
*nodes = vec![];
|
|
||||||
while let Some(node) = ts.pop() {
|
|
||||||
nodes.push(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_string(bytes: Vec<u8>) -> String {
|
fn to_string(bytes: Vec<u8>) -> String {
|
||||||
from_utf8(bytes.as_slice()).unwrap().to_string()
|
from_utf8(bytes.as_slice()).unwrap().to_string()
|
||||||
}
|
}
|
||||||
|
@ -426,14 +385,23 @@ impl TryFrom<ValueInfoProto> for State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This function moves inputs that are also present in the initializer to the node's states vector.
|
||||||
|
// It also removes inputs that are already present in the states vector.
|
||||||
fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
|
fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
|
||||||
|
// Iterate over each node in the graph
|
||||||
nodes.iter_mut().for_each(|node| {
|
nodes.iter_mut().for_each(|node| {
|
||||||
|
// Create a new vector to hold the node's states
|
||||||
let mut node_states = Vec::new();
|
let mut node_states = Vec::new();
|
||||||
|
// Create a new vector to hold the node's inputs
|
||||||
let mut inputs = Vec::new();
|
let mut inputs = Vec::new();
|
||||||
|
|
||||||
|
// Iterate over each input in the node's inputs vector
|
||||||
for input in node.inputs.iter() {
|
for input in node.inputs.iter() {
|
||||||
|
// Iterate over each tensor in the initializer
|
||||||
for init in initializer.iter() {
|
for init in initializer.iter() {
|
||||||
|
// If the input name matches the tensor name in the initializer
|
||||||
if init.name == input.name {
|
if init.name == input.name {
|
||||||
|
// Add the tensor to the node's states vector
|
||||||
node_states.push(State {
|
node_states.push(State {
|
||||||
name: init.name.clone(),
|
name: init.name.clone(),
|
||||||
ty: StateType::Tensor(init.clone().try_into().unwrap()),
|
ty: StateType::Tensor(init.clone().try_into().unwrap()),
|
||||||
|
@ -442,7 +410,10 @@ fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Swap the node's inputs vector with the temporary inputs vector
|
||||||
core::mem::swap(&mut inputs, &mut node.inputs);
|
core::mem::swap(&mut inputs, &mut node.inputs);
|
||||||
|
|
||||||
|
// Filter out inputs that are already present in the node's states vector
|
||||||
node.inputs = inputs
|
node.inputs = inputs
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|input| {
|
.filter(|input| {
|
||||||
|
@ -455,6 +426,8 @@ fn move_inputs_to_state(nodes: &mut Vec<Node>, initializer: &[TensorProto]) {
|
||||||
true
|
true
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
// Set the node's states vector to the temporary node_states vector
|
||||||
node.states = node_states;
|
node.states = node_states;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -482,7 +455,7 @@ fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
|
||||||
old_names
|
old_names
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rename the inputs in the graph and return a map of the old names to the new names.
|
/// Rename the inputs and output in the graph and return a map of the old names to the new names.
|
||||||
///
|
///
|
||||||
/// The inputs are renamed to be unique and to be in the format of conv2_in1, conv2_in2, etc.
|
/// The inputs are renamed to be unique and to be in the format of conv2_in1, conv2_in2, etc.
|
||||||
/// This is done to be consistent with the naming convention of the nodes and allow to be used as rust identifiers.
|
/// This is done to be consistent with the naming convention of the nodes and allow to be used as rust identifiers.
|
||||||
|
@ -493,13 +466,13 @@ fn rename_inputs(
|
||||||
) -> HashMap<String, String> {
|
) -> HashMap<String, String> {
|
||||||
let mut old_names = HashMap::new();
|
let mut old_names = HashMap::new();
|
||||||
|
|
||||||
|
// rename all graph input names to follow input1, input2, input3, etc.
|
||||||
|
// (assumes the input names are already unique)
|
||||||
let mut counter = 1;
|
let mut counter = 1;
|
||||||
for input in inputs.iter_mut() {
|
for input in inputs.iter_mut() {
|
||||||
let old_name = input.name.clone();
|
let old_name = input.name.clone();
|
||||||
let new_name = format!("input{}", counter);
|
let new_name = format!("input{}", counter);
|
||||||
|
|
||||||
input.name = new_name.clone();
|
input.name = new_name.clone();
|
||||||
|
|
||||||
old_names.insert(old_name, new_name);
|
old_names.insert(old_name, new_name);
|
||||||
counter += 1;
|
counter += 1;
|
||||||
}
|
}
|
||||||
|
@ -513,13 +486,6 @@ fn rename_inputs(
|
||||||
.and_modify(|e| *e += 1)
|
.and_modify(|e| *e += 1)
|
||||||
.or_insert(1);
|
.or_insert(1);
|
||||||
|
|
||||||
// loop through node inputs and rename them with previously replaced names
|
|
||||||
for input in node.inputs.iter_mut() {
|
|
||||||
if let Some(new_name) = old_names.get(&input.name) {
|
|
||||||
input.name = new_name.clone();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// loop through node outputs and rename them and store the new name <-> old name mapping
|
// loop through node outputs and rename them and store the new name <-> old name mapping
|
||||||
for output in node.outputs.iter_mut() {
|
for output in node.outputs.iter_mut() {
|
||||||
let old_name = output.name.clone();
|
let old_name = output.name.clone();
|
||||||
|
@ -529,48 +495,63 @@ fn rename_inputs(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for node in nodes.iter_mut() {
|
||||||
|
// loop through node inputs and rename them with previously replaced names
|
||||||
|
for input in node.inputs.iter_mut() {
|
||||||
|
if let Some(new_name) = old_names.get(&input.name) {
|
||||||
|
input.name = new_name.clone();
|
||||||
|
} else {
|
||||||
|
panic!("Input {} not found in old_names", input.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Rename the graph outputs
|
// Rename the graph outputs
|
||||||
for output in outputs.iter_mut() {
|
for output in outputs.iter_mut() {
|
||||||
if let Some(new_name) = old_names.get(&output.name) {
|
if let Some(new_name) = old_names.get(&output.name) {
|
||||||
output.name = new_name.clone();
|
output.name = new_name.clone();
|
||||||
|
} else {
|
||||||
|
println!("{:#?}", old_names);
|
||||||
|
panic!("Output {} not found in old_names", output.name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
old_names
|
old_names
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find the node that produces the given output
|
// Define a trait for topological sorting
|
||||||
fn lookup_node_by_output(nodes: &Vec<Node>, input: &str) -> Option<Node> {
|
trait TopologicalSortable {
|
||||||
for node in nodes.iter() {
|
fn is_top_sorted(&self) -> bool;
|
||||||
if node.outputs.iter().any(|x| x.name == *input) {
|
|
||||||
return Some(node.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sort nodes in topological order
|
impl TopologicalSortable for Vec<Node> {
|
||||||
pub fn topsort(nodes: &Vec<Node>) -> TopologicalSort<Node> {
|
fn is_top_sorted(&self) -> bool {
|
||||||
if nodes.is_empty() {
|
// Create a hashmap to store the position of each node in the vector
|
||||||
panic!("No nodes to sort");
|
let position: HashMap<String, usize> = self
|
||||||
}
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(idx, node)| (node.name.clone(), idx))
|
||||||
|
.collect();
|
||||||
|
|
||||||
let mut ts = TopologicalSort::new();
|
// Iterate over each node in the vector
|
||||||
|
for node in self {
|
||||||
// If there is only one node, then it is the only dependency
|
// Iterate over each output of the node
|
||||||
if nodes.len() == 1 {
|
for output in &node.outputs {
|
||||||
ts.insert(nodes[0].clone());
|
// Iterate over each other node in the vector
|
||||||
return ts;
|
for other_node in self {
|
||||||
}
|
// If the other node has an input that matches the current output
|
||||||
|
if other_node.inputs.contains(output) {
|
||||||
for node in nodes.iter() {
|
// If the position of the current node is greater than the position of the other node
|
||||||
for input in node.inputs.iter() {
|
if position[&node.name] > position[&other_node.name] {
|
||||||
match lookup_node_by_output(nodes, input.name.as_str()) {
|
// The vector is not topologically sorted
|
||||||
Some(prec) => ts.add_dependency(prec, node.clone()),
|
return false;
|
||||||
None => {}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
ts
|
// The vector is topologically sorted
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -84,9 +84,6 @@ pub struct ONNXGraph {
|
||||||
/// The outputs of the graph.
|
/// The outputs of the graph.
|
||||||
pub outputs: Vec<Argument>,
|
pub outputs: Vec<Argument>,
|
||||||
|
|
||||||
/// The states of the graph.
|
|
||||||
pub states: Vec<State>,
|
|
||||||
|
|
||||||
/// The original node names.
|
/// The original node names.
|
||||||
pub old_node_names: HashMap<String, String>,
|
pub old_node_names: HashMap<String, String>,
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue