From c94e7438293b1ce441f75fbed8dd651ff1b54b92 Mon Sep 17 00:00:00 2001 From: mepatrick73 <114622680+mepatrick73@users.noreply.github.com> Date: Fri, 23 Aug 2024 12:46:31 -0400 Subject: [PATCH] Tensor type indent fix (#2196) * pad-input-fix: adding support for pads as attributes * final fix * undo pad changes --- crates/burn-import/src/burn/graph.rs | 4 ++-- crates/burn-import/src/burn/ty.rs | 14 +++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/crates/burn-import/src/burn/graph.rs b/crates/burn-import/src/burn/graph.rs index ee66399b3..92143f73d 100644 --- a/crates/burn-import/src/burn/graph.rs +++ b/crates/burn-import/src/burn/graph.rs @@ -556,13 +556,13 @@ impl BurnGraph { // Get the input and output types of the graph using passed in names input_names.iter().for_each(|input| { self.graph_input_types - .push(inputs.get(input).unwrap().clone()); + .push(inputs.get(&TensorType::format_name(input)).unwrap().clone()); }); output_names.iter().for_each(|output| { self.graph_output_types.push( outputs - .get(output) + .get(&TensorType::format_name(output)) .unwrap_or_else(|| panic!("Output type is not found for {output}")) .clone(), ); diff --git a/crates/burn-import/src/burn/ty.rs b/crates/burn-import/src/burn/ty.rs index 9bf04d3d0..191333d1a 100644 --- a/crates/burn-import/src/burn/ty.rs +++ b/crates/burn-import/src/burn/ty.rs @@ -119,6 +119,17 @@ impl ShapeType { } impl TensorType { + // This is used, because Tensors might have number literal name, which cannot be + // used as a variable name. + pub fn format_name(name: &str) -> String { + let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit()); + if name_is_number { + format!("_{}", name) + } else { + name.to_string() + } + } + pub fn new>( name: S, dim: usize, @@ -131,8 +142,9 @@ impl TensorType { kind, shape ); } + let formatted_name = Self::format_name(name.as_ref()); Self { - name: Ident::new(name.as_ref(), Span::call_site()), + name: Ident::new(&formatted_name, Span::call_site()), dim, kind, shape,