mirror of https://github.com/tracel-ai/burn.git
Optimize argument handling and improve ONNX graph building (#1857)
* draft for alternative burn import design
* passes onnx test, fails to build example
* pushing to test example on main
* fixed the issue with the example
* passes the test now
* spring cleaning and minor code changes
* removed pub visibility from most graph_data fields and functions
* comment fixes
* went ahead and removed the constant check for now
* removed unused function arg
This commit is contained in:
parent 9a32e53e65
commit effce28b72
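The diff below replaces the old `OnnxGraphIO` bookkeeping with a single `GraphData` struct that owns the processed nodes, graph inputs/outputs, and initializers, so dim inference and IO renaming no longer need a separate pass over a shared name map. A rough sketch of the new per-node flow in `OnnxGraphBuilder::build`, paraphrased from the hunks below (not a verbatim copy; error handling and helper bodies are abbreviated):

// Paraphrased from the build() hunk further down; `self` is the OnnxGraphBuilder.
let mut graph_data = GraphData::new(
    &model_proto.graph.input,
    &model_proto.graph.output,
    &model_proto.graph.initializer,
);
let mut node_iter = model_proto.graph.node.iter().peekable();
while let Some(node_proto) = node_iter.next() {
    // Node inputs are now resolved through GraphData::init_in during conversion.
    let mut node = convert_node_proto(node_proto, &graph_data);
    remap_node_type(&mut node);
    coalesce(&mut node, &mut node_iter, &graph_data);
    self.handle_node_renaming(&mut node);
    self.handle_identity(&mut node, &graph_data);
    self.check_constants(&mut node, &graph_data);
    self.handle_unsqueeze(&mut node, &graph_data);
    dim_inference(&mut node); // no longer takes the graph IO
    graph_data.add_node(node); // marks inputs as passed and renames the outputs
}
// consume() drops unused inputs and resolves graph outputs from node outputs.
let (nodes, inputs, outputs) = graph_data.consume();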
@@ -1,7 +1,7 @@
use std::{iter::Peekable, slice::Iter};

use super::{
from_onnx::OnnxGraphIO,
from_onnx::GraphData,
ir::{AttributeValue, Node, NodeType},
proto_conversion::convert_node_proto,
protos::NodeProto,
@@ -12,12 +12,12 @@ use crate::onnx::ir::{ArgType, Data, TensorType};
pub fn coalesce(
node: &mut Node,
nodes_iter: &mut Peekable<Iter<NodeProto>>,
graph_io: &OnnxGraphIO,
graph_data: &GraphData,
) {
match node.node_type {
NodeType::Gemm => convert_gemm_to_linear(node),
NodeType::MatMul => {
convert_matmul_to_linear(node, nodes_iter, graph_io);
convert_matmul_to_linear(node, nodes_iter, graph_data);
}
_ => {}
}
@@ -120,7 +120,7 @@ fn transpose_flattened<T: Copy>(matrix: Vec<T>, rows: usize, cols: usize) -> Vec
pub(crate) fn convert_matmul_to_linear(
node: &mut Node,
iter_mut: &mut Peekable<Iter<NodeProto>>,
graph_io: &OnnxGraphIO,
graph_data: &GraphData,
) {
if node.inputs.len() != 2 {
panic!("MatMul node must have 2 inputs");
@@ -141,9 +141,10 @@ pub(crate) fn convert_matmul_to_linear(
// Convert the node to Linear
node.node_type = NodeType::Linear;

log::debug!("peeking next node for bias conversion");
// Check the next node for potential conversion
if let Some(peek_node) = iter_mut.peek() {
let peek_node = convert_node_proto(peek_node, graph_io).clone();
let peek_node = convert_node_proto(peek_node, graph_data);
if is_add_node_with_bias(&peek_node, node) {
convert_and_remove_add_node(&peek_node, node);

@@ -4,14 +4,13 @@ use core::panic;
use protobuf::Enum;

use super::{
from_onnx::OnnxGraphIO,
ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
op_configuration::flatten_config,
protos::tensor_proto::DataType,
};

/// Infer the dimension of each output tensor and update them.
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
pub fn dim_inference(node: &mut Node) {
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::ArgMax => argmax_update_outputs(node),
@@ -81,8 +80,6 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
_ => temporary_pass_through_stub(node),
}

graph_io.update_tensor_output(node);
}

fn constant_update_outputs(node: &mut Node) {

@@ -4,11 +4,12 @@ use std::{
path::Path,
};

use crate::onnx::{node_remap::remap_node_type, proto_conversion::convert_node_proto};
use crate::onnx::node_remap::remap_node_type;

use super::{
coalesce::coalesce,
ir::{Data, OnnxGraph, TensorType},
proto_conversion::convert_node_proto,
protos::{ModelProto, NodeProto, TensorProto, ValueInfoProto},
};

@@ -30,288 +31,213 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
NodeType::Squeeze,
];

#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) enum IOEntry {
In(usize),
Out(usize),
Node(usize),
Node(usize, usize),
}

pub(crate) struct OnnxGraphIO {
/// The inputs for the Graph
pub(crate) inputs: Vec<Argument>,
/// The outputs for the Graph
pub(crate) outputs: Vec<Argument>,
/// Initializers
pub struct GraphData {
/// The nodes that have been processed, used to copy the outputs to a child node
processed_nodes: Vec<Node>,
/// The inputs of the graph
inputs: Vec<Argument>,
/// The outputs of the graph
outputs: Vec<Argument>,
/// The initializers of the graph
pub(crate) initializers: HashMap<String, Argument>,
///updated names of outputs of node not stored in the graph
node_out: Vec<Argument>,
pub(crate) old_io_names: HashMap<String, IOEntry>,
/// Maps the original input name to a graph input
input_name_map: HashMap<String, IOEntry>,
/// Maps the updated input name to the original input name. Required to check if the input is an initializer
input_key_map: HashMap<String, String>,
}

impl OnnxGraphIO {
impl GraphData {
pub(crate) fn new(
inputs: &Vec<ValueInfoProto>,
outputs: &Vec<ValueInfoProto>,
initializers: &Vec<TensorProto>,
) -> Self {
let mut old_io_names = HashMap::new();
let mut in_count = 1;
let mut input_name_map = HashMap::new();
let mut input_key_map = HashMap::new();

let constants = initializers
.iter()
.map(|x| (x.name.clone(), Argument::from_initializer(x)))
.collect::<HashMap<String, Argument>>();

let outputs = outputs
.iter()
.map(|x| Argument::try_from(x.clone()).unwrap())
.collect::<Vec<Argument>>();
let inputs = inputs
.iter()
.enumerate()
.map(|(i, x)| {
let in_name = format!("input{}", in_count);
old_io_names.insert(x.name.clone(), IOEntry::In(i));
let in_name = format!("input{}", i + 1);

input_name_map.insert(x.name.clone(), IOEntry::In(i));
input_key_map.insert(in_name.clone(), x.name.clone());

let mut arg = Argument::try_from(x.clone()).unwrap();
if let Some(initial_arg) = constants.get(&x.name) {
if arg.value.is_none() {
log::warn!("Input {} is also an initializer. Initializer as default values are currently not supported", x.name);
arg.copy_value(initial_arg);
}
}

in_count += 1;
arg.name = in_name;
arg
})
.collect::<Vec<Argument>>();

let outputs = outputs
.iter()
.enumerate()
.map(|(i, x)| {
old_io_names.insert(x.name.clone(), IOEntry::Out(i));
Argument::try_from(x.clone()).unwrap()
})
.collect::<Vec<Argument>>();

let constants = initializers
.iter()
.map(|x| (x.name.clone(), Argument::from_initializer(x)))
.collect::<HashMap<String, Argument>>();

Self {
inputs,
outputs,
initializers: constants,
node_out: Vec::new(),
old_io_names,
processed_nodes: Vec::new(),
input_name_map,
input_key_map,
}
}

fn update_name(&mut self, arg: &Argument, new_name: &str) {
match self.old_io_names.get(&arg.name) {
Some(IOEntry::In(_)) => {
panic!("input names are set from the beginning");
}
Some(IOEntry::Out(i)) => {
let arg = self.outputs.get_mut(*i).unwrap();
arg.name = new_name.to_string();
}
Some(IOEntry::Node(i)) => {
let arg = self.node_out.get_mut(*i).unwrap();
arg.name = new_name.to_string();
}
/// Get the value of an input from the original input name. Used during proto conversion
pub(crate) fn init_in(&self, proto_str: &str) -> Argument {
match self.input_name_map.get(proto_str) {
None => {
//Constants, Casts wound up here before API changes
panic!(
"Tried to update the name of {} to {} but entry doesn't exist in the map",
arg.name, new_name
)
}
}
}

/// Used to initialize the input arguments for nodes. Names need to remain the same because
/// currently the old names are the key for accessing the Argument
pub fn init_in(&self, proto_str: String) -> Argument {
match self.old_io_names.get(&proto_str) {
None => {
if let Some(init_arg) = self.initializers.get(&proto_str) {
//NOTE: if initializers are guaranteed to be unique, (I think they are
//need to confirm) then we could pop the initializer from the map
if let Some(init_arg) = self.initializers.get(proto_str) {
init_arg.clone()
} else {
Argument::new(proto_str)
log::warn!(
"Input {} not found, should only happen when peeking",
proto_str
);
Argument::new(proto_str.to_string())
}
}

Some(IOEntry::In(i)) => {
let mut arg = self.inputs[*i].clone();

arg.name = proto_str;
arg.passed = true;
arg
}
Some(IOEntry::Node(i)) => {
let mut arg = self.node_out[*i].clone();
arg.name = proto_str;
arg
}
Some(IOEntry::Out(_)) => {
panic!("graph output {} can't be a Node input", &proto_str)
}
Some(IOEntry::In(i)) => self.inputs[*i].clone(),
Some(IOEntry::Node(i, j)) => self.processed_nodes[*i].outputs[*j].clone(),
}
}

fn insert(&mut self, arg: &Argument, new_name: &str) {
if let Some(idx) = self.old_io_names.get(&arg.name) {
if let IOEntry::Node(idx) = idx {
if self.node_out[*idx].name == arg.name {
self.node_out[*idx].name = new_name.to_string();
return;
/// Mark the graph_inputs to a node as passed, unless they are also initializers
fn mark_input_passed(&mut self, node: &Node) {
// we have to double map the inputs because the input might be replaced by an initializer
node.inputs.iter().for_each(|node_input| {
if let Some(old_input_name) = self.input_key_map.get(&node_input.name) {
if !self.initializers.contains_key(old_input_name) {
match self.input_name_map.get(old_input_name) {
Some(IOEntry::In(i)) => self.inputs[*i].passed = true,
_ => {
panic!("Should not happen, please report this error");
}
}
}
} else {
panic!("arg entry with old name {} is a graph IO", &arg.name);
}
}

let idx = self.node_out.len();
self.old_io_names
.insert(arg.name.clone(), IOEntry::Node(idx));
self.node_out.push(arg.clone());
self.node_out[idx].name = new_name.to_string();
});
}

/// Copies node outputs to graph IO. Used at the end of dim inference.
pub(crate) fn update_tensor_output(&mut self, node: &Node) {
for node_output in node.outputs.iter() {
match self.old_io_names.get(&node_output.name) {
Some(IOEntry::In(i)) => {
let arg = self.inputs.get_mut(*i).unwrap();
arg.copy_value(node_output);
}
Some(IOEntry::Out(i)) => {
let arg = self.outputs.get_mut(*i).unwrap();
arg.copy_value(node_output);
//Set the output to passed since it's been altered by a Node
arg.passed = true;
}
Some(IOEntry::Node(_)) => {
panic!("This output is from another node");
}
None => {
log::debug!("inserting with name {:?}", &node_output.name);
let idx = self.node_out.len();
self.old_io_names
.insert(node_output.name.clone(), IOEntry::Node(idx));
self.node_out.push(node_output.clone());
}
}
/// This function does three things:
/// 1. marks the inputs as passed
/// 2. maps the old output names to the node output
/// 3. renames the node output
fn add_node(&mut self, mut node: Node) {
log::debug!("adding node {:?}", &node.name);
self.mark_input_passed(&node);
let mut out_count = 1;
for output in node.outputs.iter_mut() {
self.input_name_map.insert(
output.name.clone(),
IOEntry::Node(self.processed_nodes.len(), 0),
);
output.name = format!("{}_out{}", node.name, out_count);
out_count += 1;
}
self.processed_nodes.push(node);
}

/// Used by handle unsqeeze to remap the output of a node to a new name
/// expected match if it exists is either a graph input or graph output
pub(crate) fn get_node_output(&self, old_name: &str) -> Option<&Argument> {
match self.old_io_names.get(old_name) {
Some(IOEntry::In(i)) => self.inputs.get(*i),
Some(IOEntry::Out(i)) => self.outputs.get(*i),
Some(IOEntry::Node(_)) => panic!("This is a node output"),
None => None,
}
/// Consumes the graph data and returns the processed nodes, filtered inputs and outputs
fn consume(mut self) -> (Vec<Node>, Vec<Argument>, Vec<Argument>) {
self.inputs.retain(|x| x.passed);
let outputs = self
.outputs
.into_iter()
.filter_map(|x| match self.input_name_map.get(&x.name) {
Some(IOEntry::Node(i, j)) => Some(self.processed_nodes[*i].outputs[*j].clone()),
_ => None,
})
.collect();
(self.processed_nodes, self.inputs, outputs)
}

/// Get the updated name of a Node Input, which should be
/// either a graph input or a node output.
/// Will return None if the it isn't a graph input or node output(like an initializer)
/// Will panic if it's a graph output
fn get_new_name(&mut self, old_name: &str) -> Option<String> {
match self.old_io_names.get(old_name) {
Some(IOEntry::In(i)) => {
//FIXME: technically in the spec, initializers are default values
//for optional inputs, but implementing that would require reworking
//the way the graph is built, and it's not clear burn users are using initializers
//in that way
// see https://github.com/onnx/onnx/issues/2660
if self.initializers.contains_key(old_name) {
None
} else {
//set the input as passed since a node is referencing it
self.inputs[*i].passed = true;
Some(self.inputs[*i].name.clone())
}
}
Some(IOEntry::Out(_)) => {
panic!(
"you just tried to get an updated name on a graph output: {}",
old_name
)
}
Some(IOEntry::Node(i)) => Some(self.node_out[*i].name.clone()),
None => None,
}
/// Used to get the output of the graph by name. Only used to remap unsqueeze nodes
pub fn get_graph_output(&self, name: &str) -> Option<&Argument> {
self.outputs.iter().find(|x| x.name == name)
}

// Since Nodes are added at the end of conversion, the current index is the length of the processed nodes
/// Get the current index of the processed nodes. Useful when lifting values or marking nodes for removal
pub fn get_current_index(&self) -> usize {
self.processed_nodes.len()
}
}

#[derive(Default)]
pub(crate) struct OnnxGraphBuilder {
nodes: Vec<Node>,
inputs: Vec<Argument>,
outputs: Vec<Argument>,
/// Counter for node names, used for renaming nodes
node_name_counter: HashMap<NodeType, usize>,
/// Nodes to remove
/// Nodes to remove. Note may be moved to graph data if we implement support for custom ops
nodes_to_remove: HashSet<usize>,
/// Map from constant node output names to indices of constant nodes
constants_map: HashMap<String, usize>,
/// Node types that should be lifted to constants
constants_types: HashSet<NodeType>,
/// Map from identity node output names to indices of identity nodes
identity_idx: HashMap<String, usize>,
node_name_counter: HashMap<NodeType, usize>,
}

impl OnnxGraphBuilder {
pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) {
pub(crate) fn build(mut self, model_proto: &ModelProto) -> OnnxGraph {
self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect();

let mut graph_io = OnnxGraphIO::new(
let mut graph_data = GraphData::new(
&model_proto.graph.input,
&model_proto.graph.output,
&model_proto.graph.initializer,
);

self.nodes = Vec::with_capacity(model_proto.graph.node.len());
let mut and_idx = 0;
let mut node_iter = model_proto.graph.node.iter().peekable();

while let Some(node_proto) = node_iter.next() {
let mut node = convert_node_proto(node_proto, &graph_io);
let mut node = convert_node_proto(node_proto, &graph_data);

remap_node_type(&mut node);

coalesce(&mut node, &mut node_iter, &graph_io);
self.handle_node_renaming(&mut node);
self.handle_identity(&mut node, and_idx);
self.check_constants(&mut node, and_idx, &mut graph_io);
self.handle_unsqueeze(&mut node, &graph_io);
coalesce(&mut node, &mut node_iter, &graph_data);
self.handle_identity(&mut node, &graph_data);
self.check_constants(&mut node, &graph_data);
// NOTE: potential start of custom functions
// can filter, coalesce, or modify the nodes here
// args : node, peek_iter, graph_data
self.handle_unsqueeze(&mut node, &graph_data);

dim_inference(&mut node, &mut graph_io);

rename_io(&mut node, &mut graph_io);

self.nodes.push(node);
and_idx += 1;
dim_inference(&mut node);
graph_data.add_node(node);
}

let mut i = 0;
self.nodes.retain(|_x| {
let res = !self.nodes_to_remove.contains(&i);
i += 1;
res
});
let OnnxGraphIO {
mut inputs,
mut outputs,
..
} = graph_io;

let (mut processed_nodes, inputs, outputs) = graph_data.consume();
// Remove the graph inputs/output that are not used by any node
remove_unused_graph_inputs(&mut inputs, &mut outputs);
self.inputs = inputs;
self.outputs = outputs;
let mut i = 0;
processed_nodes.retain(|_| {
let keep = !self.nodes_to_remove.contains(&i);
i += 1;
keep
});
OnnxGraph {
nodes: processed_nodes,
inputs,
outputs,
}
}

fn handle_node_renaming(&mut self, node: &mut Node) {
@@ -328,17 +254,20 @@ impl OnnxGraphBuilder {
node.name.clone_from(&new_name);
}

fn check_constants(&mut self, node: &mut Node, i: usize, _graph_io: &mut OnnxGraphIO) {
fn check_constants(&mut self, node: &mut Node, graph_data: &GraphData) {
if node.node_type == NodeType::Constant
|| (node.node_type == NodeType::Identity && node.inputs[0].value.is_some())
{
self.constants_map.insert(node.outputs[0].name.clone(), i);
self.constants_map.insert(
format!("{}_out{}", &node.name, 1),
graph_data.get_current_index(),
);
} else if self.constants_types.contains(&node.node_type) {
log::debug!("checking node {} for constants", &node.name);
for input in node.inputs.iter_mut().skip(1) {
log::debug!("checking input {:?} for const", input);
if let Some(const_idx) = self.constants_map.get(&input.name) {
let constant = &self.nodes[*const_idx];
let constant = &graph_data.processed_nodes[*const_idx];
log::debug!(
"input {} matched constant node {}",
&input.name,
@@ -362,29 +291,29 @@
/// Check if the unsqueeze node has a rhs value (rhs is constant) and if not remap it to a reshape
/// Needs to be called after node renaming to ensure that the rhs name is correct
/// Needs to be called after constant lifting to ensure that the rhs value exists
fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) {
fn handle_unsqueeze(&mut self, node: &mut Node, graph_data: &GraphData) {
if node.node_type == NodeType::Unsqueeze
&& node.inputs.len() > 1
&& node.inputs[1].value.is_none()
{
if let Some(in_arg) = graph_io.get_node_output(&node.outputs[0].name) {
remap_unsqueeze_to_reshape(node, in_arg);
//if the output has a shape, it's only because it's a graph output
if let Some(out_arg) = graph_data.get_graph_output(&node.outputs[0].name) {
remap_unsqueeze_to_reshape(node, out_arg);
}
}
}

fn handle_identity(&mut self, node: &mut Node, i: usize) {
fn handle_identity(&mut self, node: &mut Node, graph_data: &GraphData) {
if node.node_type == NodeType::Identity && node.inputs[0].value.is_none() {
log::debug!("\nfound identity node:\n{:?}\n", &node);
let i = graph_data.get_current_index();
//map the output name to check for pass through values
self.identity_idx.insert(node.outputs[0].name.clone(), i);
self.identity_idx.insert(format!("{}_out1", &node.name), i);
self.nodes_to_remove.insert(i);
} else {
//NOTE: it might be possible to rework the API to handle all "per input" operations
//in a new function that operates on each input.
node.inputs.iter_mut().for_each(|x| {
if let Some(identity_idx) = self.identity_idx.get(&x.name) {
let input_name = &self.nodes[*identity_idx].inputs[0].name;
let input_name = &graph_data.processed_nodes[*identity_idx].inputs[0].name;

x.name.clone_from(input_name);
}
@@ -431,118 +360,50 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
);

log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
let mut builder = OnnxGraphBuilder::default();
builder.node_gen(&onnx_model);

let OnnxGraphBuilder {
nodes,
inputs: inner_inputs,
outputs: inner_outputs,
..
} = builder;
let builder = OnnxGraphBuilder::default();
let graph = builder.build(&onnx_model);

log::info!("Finished parsing ONNX file: {}", onnx_path.display());

OnnxGraph {
nodes,
inputs: inner_inputs,
outputs: inner_outputs,
}
graph
}

/// Remap the unsqueeze node to a reshape node, Should only be called after
/// node renaming has been done. avoids marking rhs as passed so that it can be
/// properly deleted if nothing else uses it
fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
match node.outputs[0].ty {
ArgType::Tensor(ref mut tensor_type) => {
if let ArgType::Tensor(arg_tensor) = &out_arg.ty {
tensor_type.shape.clone_from(&arg_tensor.shape);
let inner = arg_tensor
.shape
.clone()
.unwrap()
.into_iter()
.map(|x| x as i64)
.collect::<Vec<i64>>();
let shape_len = inner.len();
let new_rhs_value = Some(Data::Int64s(inner));
//moving the remap to here
let rhs_arg = Argument {
name: format!("{}_generated_const", node.name),
ty: ArgType::Tensor(TensorType {
elem_type: super::ir::ElementType::Int64,
dim: 1,
shape: Some(vec![shape_len]),
}),
value: new_rhs_value,
passed: false,
};
node.inputs[1] = rhs_arg;
node.outputs[0] = out_arg.clone();
node.node_type = NodeType::Reshape;
}
/// Remap the unsqueeze node to a reshape node
pub(crate) fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
match &out_arg.ty {
ArgType::Tensor(output_tensor) => {
let inner = output_tensor
.shape
.clone()
.unwrap()
.into_iter()
.map(|x| x as i64)
.collect::<Vec<i64>>();
let shape_len = inner.len();
let new_rhs_value = Some(Data::Int64s(inner));
//moving the remap to here
let rhs_arg = Argument {
name: format!("{}_generated_const", &node.name),
ty: ArgType::Tensor(TensorType {
elem_type: super::ir::ElementType::Int64,
dim: 1,
shape: Some(vec![shape_len]),
}),
value: new_rhs_value,
passed: false,
};
// ? should this replace the old input (reuse the old key) or should it be a new key
// going with new key for now
node.inputs[1] = rhs_arg;
node.outputs[0] = out_arg.clone();
node.node_type = NodeType::Reshape;
}
_ => {}
}
}

/// Rename the inputs and output in the graph and return a map of
/// the old names to the new names.
///
/// The inputs are renamed to be unique and to be in the format of
/// conv2_in1, conv2_in2, etc. This is done to be consistent with
/// the naming convention of the nodes and allow to be used as rust identifiers.
/// Rename the inputs and output in the graph and return a map of
/// the old names to the new names.
fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) {
log::debug!("checking inputs for node {:?}", &node.name);
for node_input in node.inputs.iter_mut() {
if let Some(input_name) = graph_io.get_new_name(&node_input.name) {
node_input.passed = true;
node_input.name.clone_from(&input_name);
} else {
node_input.name = "".to_string();
node_input.passed = false;
}
}
let mut out_count = 1;
if node.node_type == NodeType::Constant || node.node_type == NodeType::Identity {
let new_name = format!("{}_out{}", node.name, out_count);
graph_io.insert(&node.outputs[0], &new_name);
node.outputs[0].name.clone_from(&new_name);
log::debug!("Found {} constant", new_name);
} else {
for output in node.outputs.iter_mut() {
log::debug!("output name: {}", &output.name);

let new_name = format!("{}_out{}", node.name, out_count);

graph_io.update_name(output, &new_name);

output.name.clone_from(&new_name);
out_count += 1;
}
}
}

/// Removes the graph inputs/output that are not used by any node.
///
/// In older ONNX models, the inputs and outputs are not always used by the nodes.
/// For example, the input could be used as a state instead of an input. Since the
/// inputs with initializers are moved to the states vector, the inputs vector could
/// contain unused inputs. The same is true for the outputs.
///
/// Generally, it's a good idea to remove unused inputs/outputs because it makes the
/// generated code cleaner and easier to read.
fn remove_unused_graph_inputs(inputs: &mut Vec<Argument>, outputs: &mut Vec<Argument>) {
// Remove inputs that are not used by any node
inputs.retain(|input| input.passed);

// Remove outputs that are not used by any node
outputs.retain(|output| output.passed);
}

// Define a trait for topological sorting
trait TopologicalSortable {
fn is_top_sorted(&self) -> bool;

@@ -2,7 +2,7 @@ use std::str::{from_utf8, FromStr};

use crate::onnx::ir::TensorType;

use super::from_onnx::OnnxGraphIO;
use super::from_onnx::GraphData;
use super::ir::Dim;
use super::ir::{
ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor,
@@ -180,19 +180,18 @@ pub fn convert_vec_attrs_proto(attrs: Vec<AttributeProto>) -> Attributes {
result
}

pub fn convert_node_proto(node: &NodeProto, graph_io: &OnnxGraphIO) -> Node {
pub fn convert_node_proto(node: &NodeProto, graph_data: &GraphData) -> Node {
let name = node.name.clone();

log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str());

let inputs = node
.input
.clone()
.into_iter()
.map(|x| graph_io.init_in(x))
.collect();
let inputs = node.input.iter().map(|x| graph_data.init_in(x)).collect();

let outputs = node.output.clone().into_iter().map(Argument::new).collect();
let outputs = node
.output
.iter()
.map(|x| Argument::new(x.to_string()))
.collect();

let attrs = convert_vec_attrs_proto(node.attribute.clone());