mirror of https://github.com/tracel-ai/burn.git
Tensor type indent fix (#2196)
* pad-input-fix: adding support for pads as attributes * final fix * undo pad changes
This commit is contained in:
parent
2c12d58cd8
commit
c94e743829
|
@ -556,13 +556,13 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
// Get the input and output types of the graph using passed in names
|
||||
input_names.iter().for_each(|input| {
|
||||
self.graph_input_types
|
||||
.push(inputs.get(input).unwrap().clone());
|
||||
.push(inputs.get(&TensorType::format_name(input)).unwrap().clone());
|
||||
});
|
||||
|
||||
output_names.iter().for_each(|output| {
|
||||
self.graph_output_types.push(
|
||||
outputs
|
||||
.get(output)
|
||||
.get(&TensorType::format_name(output))
|
||||
.unwrap_or_else(|| panic!("Output type is not found for {output}"))
|
||||
.clone(),
|
||||
);
|
||||
|
|
|
@ -119,6 +119,17 @@ impl ShapeType {
|
|||
}
|
||||
|
||||
impl TensorType {
|
||||
// This is used, because Tensors might have number literal name, which cannot be
|
||||
// used as a variable name.
|
||||
pub fn format_name(name: &str) -> String {
|
||||
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
|
||||
if name_is_number {
|
||||
format!("_{}", name)
|
||||
} else {
|
||||
name.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new<S: AsRef<str>>(
|
||||
name: S,
|
||||
dim: usize,
|
||||
|
@ -131,8 +142,9 @@ impl TensorType {
|
|||
kind, shape
|
||||
);
|
||||
}
|
||||
let formatted_name = Self::format_name(name.as_ref());
|
||||
Self {
|
||||
name: Ident::new(name.as_ref(), Span::call_site()),
|
||||
name: Ident::new(&formatted_name, Span::call_site()),
|
||||
dim,
|
||||
kind,
|
||||
shape,
|
||||
|
|
Loading…
Reference in New Issue