mirror of https://github.com/tracel-ai/burn.git
Tensor type indent fix (#2196)
* pad-input-fix: adding support for pads as attributes * final fix * undo pad changes
This commit is contained in:
parent
2c12d58cd8
commit
c94e743829
|
@ -556,13 +556,13 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
||||||
// Get the input and output types of the graph using passed in names
|
// Get the input and output types of the graph using passed in names
|
||||||
input_names.iter().for_each(|input| {
|
input_names.iter().for_each(|input| {
|
||||||
self.graph_input_types
|
self.graph_input_types
|
||||||
.push(inputs.get(input).unwrap().clone());
|
.push(inputs.get(&TensorType::format_name(input)).unwrap().clone());
|
||||||
});
|
});
|
||||||
|
|
||||||
output_names.iter().for_each(|output| {
|
output_names.iter().for_each(|output| {
|
||||||
self.graph_output_types.push(
|
self.graph_output_types.push(
|
||||||
outputs
|
outputs
|
||||||
.get(output)
|
.get(&TensorType::format_name(output))
|
||||||
.unwrap_or_else(|| panic!("Output type is not found for {output}"))
|
.unwrap_or_else(|| panic!("Output type is not found for {output}"))
|
||||||
.clone(),
|
.clone(),
|
||||||
);
|
);
|
||||||
|
|
|
@ -119,6 +119,17 @@ impl ShapeType {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TensorType {
|
impl TensorType {
|
||||||
|
// This is used, because Tensors might have number literal name, which cannot be
|
||||||
|
// used as a variable name.
|
||||||
|
pub fn format_name(name: &str) -> String {
|
||||||
|
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
|
||||||
|
if name_is_number {
|
||||||
|
format!("_{}", name)
|
||||||
|
} else {
|
||||||
|
name.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new<S: AsRef<str>>(
|
pub fn new<S: AsRef<str>>(
|
||||||
name: S,
|
name: S,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
@ -131,8 +142,9 @@ impl TensorType {
|
||||||
kind, shape
|
kind, shape
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
let formatted_name = Self::format_name(name.as_ref());
|
||||||
Self {
|
Self {
|
||||||
name: Ident::new(name.as_ref(), Span::call_site()),
|
name: Ident::new(&formatted_name, Span::call_site()),
|
||||||
dim,
|
dim,
|
||||||
kind,
|
kind,
|
||||||
shape,
|
shape,
|
||||||
|
|
Loading…
Reference in New Issue