Tensor type indent fix (#2196)

* pad-input-fix: adding support for pads as attributes

* final fix

* undo pad changes
This commit is contained in:
mepatrick73 2024-08-23 12:46:31 -04:00 committed by GitHub
parent 2c12d58cd8
commit c94e743829
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 3 deletions

View File

@ -556,13 +556,13 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
// Get the input and output types of the graph using passed in names // Get the input and output types of the graph using passed in names
input_names.iter().for_each(|input| { input_names.iter().for_each(|input| {
self.graph_input_types self.graph_input_types
.push(inputs.get(input).unwrap().clone()); .push(inputs.get(&TensorType::format_name(input)).unwrap().clone());
}); });
output_names.iter().for_each(|output| { output_names.iter().for_each(|output| {
self.graph_output_types.push( self.graph_output_types.push(
outputs outputs
.get(output) .get(&TensorType::format_name(output))
.unwrap_or_else(|| panic!("Output type is not found for {output}")) .unwrap_or_else(|| panic!("Output type is not found for {output}"))
.clone(), .clone(),
); );

View File

@ -119,6 +119,17 @@ impl ShapeType {
} }
impl TensorType { impl TensorType {
// This is used, because Tensors might have number literal name, which cannot be
// used as a variable name.
pub fn format_name(name: &str) -> String {
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
if name_is_number {
format!("_{}", name)
} else {
name.to_string()
}
}
pub fn new<S: AsRef<str>>( pub fn new<S: AsRef<str>>(
name: S, name: S,
dim: usize, dim: usize,
@ -131,8 +142,9 @@ impl TensorType {
kind, shape kind, shape
); );
} }
let formatted_name = Self::format_name(name.as_ref());
Self { Self {
name: Ident::new(name.as_ref(), Span::call_site()), name: Ident::new(&formatted_name, Span::call_site()),
dim, dim,
kind, kind,
shape, shape,