Tensor type indent fix (#2196)

* pad-input-fix: adding support for pads as attributes

* final fix

* undo pad changes
This commit is contained in:
mepatrick73 2024-08-23 12:46:31 -04:00 committed by GitHub
parent 2c12d58cd8
commit c94e743829
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 3 deletions

View File

@ -556,13 +556,13 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
// Get the input and output types of the graph using passed in names
input_names.iter().for_each(|input| {
self.graph_input_types
.push(inputs.get(input).unwrap().clone());
.push(inputs.get(&TensorType::format_name(input)).unwrap().clone());
});
output_names.iter().for_each(|output| {
self.graph_output_types.push(
outputs
.get(output)
.get(&TensorType::format_name(output))
.unwrap_or_else(|| panic!("Output type is not found for {output}"))
.clone(),
);

View File

@ -119,6 +119,17 @@ impl ShapeType {
}
impl TensorType {
// This is used, because Tensors might have number literal name, which cannot be
// used as a variable name.
pub fn format_name(name: &str) -> String {
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
if name_is_number {
format!("_{}", name)
} else {
name.to_string()
}
}
pub fn new<S: AsRef<str>>(
name: S,
dim: usize,
@ -131,8 +142,9 @@ impl TensorType {
kind, shape
);
}
let formatted_name = Self::format_name(name.as_ref());
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
name: Ident::new(&formatted_name, Span::call_site()),
dim,
kind,
shape,