mirror of https://github.com/tracel-ai/burn.git
Print model structure like with PyTorch - Part 1 (#1912)
This commit is contained in:
parent
3faf544bc4
commit
2c51615471
|
@ -463,6 +463,7 @@ version = "0.14.0"
|
|||
dependencies = [
|
||||
"async-trait",
|
||||
"dashmap",
|
||||
"data-encoding",
|
||||
"derive-new",
|
||||
"getrandom",
|
||||
"indicatif",
|
||||
|
@ -471,7 +472,6 @@ dependencies = [
|
|||
"serde",
|
||||
"spin",
|
||||
"tokio",
|
||||
"uuid",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
|
@ -1469,6 +1469,12 @@ dependencies = [
|
|||
"parking_lot_core 0.9.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
|
||||
|
||||
[[package]]
|
||||
name = "deflate64"
|
||||
version = "0.1.8"
|
||||
|
|
13
Cargo.toml
13
Cargo.toml
|
@ -34,6 +34,9 @@ colored = "2.1.0"
|
|||
console_error_panic_hook = "0.1.7"
|
||||
csv = "1.3.0"
|
||||
dashmap = "5.5.3"
|
||||
data-encoding = { version = "2.6.0", default-features = false, features = [
|
||||
"alloc",
|
||||
] }
|
||||
dirs = "5.0.1"
|
||||
fake = "2.9.2"
|
||||
flate2 = "1.0.30"
|
||||
|
@ -42,16 +45,19 @@ getrandom = { version = "0.2.15", default-features = false }
|
|||
gix-tempfile = { version = "13.1.1", features = ["signals"] }
|
||||
globwalk = "0.9.1"
|
||||
hashbrown = "0.14.5"
|
||||
hound = "3.5.1"
|
||||
image = "0.25.1"
|
||||
indicatif = "0.17.8"
|
||||
js-sys = "0.3.69"
|
||||
libm = "0.2.8"
|
||||
log = { default-features = false, version = "0.4.21" }
|
||||
md5 = "0.7.0"
|
||||
percent-encoding = "2.3.1"
|
||||
pretty_assertions = "1.4.0"
|
||||
proc-macro2 = "1.0.85"
|
||||
protobuf = "3.4.0"
|
||||
protobuf-codegen = "3.4.0"
|
||||
quote = "1.0.36"
|
||||
percent-encoding = "2.3.1"
|
||||
r2d2 = "0.8.10"
|
||||
r2d2_sqlite = { version = "0.24.0" }
|
||||
rayon = "1.10.0"
|
||||
|
@ -63,6 +69,7 @@ rusqlite = { version = "0.31.0" }
|
|||
rust-format = { version = "0.3.4" }
|
||||
sanitize-filename = "0.5.0"
|
||||
serde_rusqlite = "0.35.0"
|
||||
serial_test = "3.1.1"
|
||||
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
|
||||
strum = "0.26.2"
|
||||
strum_macros = "0.26.4"
|
||||
|
@ -73,11 +80,7 @@ tokio = { version = "1.38.0", features = ["rt", "macros"] }
|
|||
tracing-appender = "0.2.3"
|
||||
tracing-core = "0.1.32"
|
||||
tracing-subscriber = "0.3.18"
|
||||
md5 = "0.7.0"
|
||||
serial_test = "3.1.1"
|
||||
web-time = "1.1.0"
|
||||
hound = "3.5.1"
|
||||
image = "0.25.1"
|
||||
zip = "2.1.3"
|
||||
|
||||
# Terminal UI
|
||||
|
|
|
@ -12,7 +12,7 @@ version.workspace = true
|
|||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = ["rand/std"]
|
||||
std = ["rand/std", "data-encoding/std"]
|
||||
doc = ["default"]
|
||||
wasm-sync = []
|
||||
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
|
||||
|
@ -27,10 +27,10 @@ web-time = { version = "1.1.0" }
|
|||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
|
||||
rand = { workspace = true }
|
||||
spin = { workspace = true } # using in place of use std::sync::Mutex;
|
||||
uuid = { workspace = true }
|
||||
spin = { workspace = true } # using in place of use std::sync::Mutex;
|
||||
derive-new = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
data-encoding = { workspace = true }
|
||||
|
||||
# Network downloader
|
||||
indicatif = { workspace = true, optional = true }
|
||||
|
|
|
@ -1,18 +1,21 @@
|
|||
use alloc::string::String;
|
||||
|
||||
use crate::rand::gen_random;
|
||||
use alloc::string::{String, ToString};
|
||||
use uuid::{Builder, Bytes};
|
||||
|
||||
use data_encoding::BASE32_DNSSEC;
|
||||
|
||||
/// Simple ID generator.
|
||||
pub struct IdGenerator {}
|
||||
|
||||
impl IdGenerator {
|
||||
/// Generates a new ID in the form of a UUID.
|
||||
/// Generates a new ID.
|
||||
pub fn generate() -> String {
|
||||
let random_bytes: Bytes = gen_random();
|
||||
// Generate 6 random bytes (281,474,976,710,656 combinations)
|
||||
let random_bytes: [u8; 6] = gen_random();
|
||||
|
||||
let uuid = Builder::from_random_bytes(random_bytes).into_uuid();
|
||||
|
||||
uuid.as_hyphenated().to_string()
|
||||
// Encode the random bytes in base32 DNSSEC
|
||||
// 6 bytes encodes to 10 lower case characters, e.g. "3uu5e6vv7c"
|
||||
BASE32_DNSSEC.encode(&random_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,547 @@
|
|||
use alloc::{
|
||||
borrow::ToOwned,
|
||||
format,
|
||||
string::{String, ToString},
|
||||
vec::Vec,
|
||||
};
|
||||
use core::any;
|
||||
use core::fmt::{Display, Write};
|
||||
|
||||
/// Default display settings for a module.
|
||||
pub trait ModuleDisplayDefault {
|
||||
/// Attributes of the module used for display purposes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `_content` - The content object that contains display settings and attributes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional content object containing the display attributes.
|
||||
fn content(&self, _content: Content) -> Option<Content>;
|
||||
|
||||
/// Gets the number of the parameters of the module.
|
||||
fn num_params(&self) -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait to implement custom display settings for a module.
|
||||
///
|
||||
/// In order to implement custom display settings for a module,
|
||||
/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)]
|
||||
/// 2. Implement ModuleDisplay trait for the module
|
||||
pub trait ModuleDisplay: ModuleDisplayDefault {
|
||||
/// Formats the module with provided display settings.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `passed_settings` - Display settings passed to the module.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A string representation of the formatted module.
|
||||
fn format(&self, passed_settings: DisplaySettings) -> String {
|
||||
let settings = if let Some(custom_settings) = self.custom_settings() {
|
||||
custom_settings.inherit(passed_settings)
|
||||
} else {
|
||||
passed_settings
|
||||
};
|
||||
|
||||
let indent = " ".repeat(settings.level * settings.indentation_size());
|
||||
let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size());
|
||||
|
||||
let settings = settings.level_up();
|
||||
|
||||
let self_type = extract_type_name::<Self>();
|
||||
|
||||
// Use custom content if it is implemented and show_all_attributes is false,
|
||||
// otherwise use default content
|
||||
let content = if !settings.show_all_attributes() {
|
||||
self.custom_content(Content::new(settings.clone()))
|
||||
.unwrap_or_else(|| {
|
||||
self.content(Content::new(settings.clone()))
|
||||
.unwrap_or_else(|| {
|
||||
panic!("Default content should be implemented for {self_type}.")
|
||||
})
|
||||
})
|
||||
} else {
|
||||
self.content(Content::new(settings.clone()))
|
||||
.unwrap_or_else(|| panic!("Default content should be implemented for {self_type}."))
|
||||
};
|
||||
|
||||
let top_level_type = if let Some(top_level_type) = content.top_level_type {
|
||||
top_level_type.to_owned()
|
||||
} else {
|
||||
self_type.to_owned()
|
||||
};
|
||||
|
||||
// If there is only one item in the content, return it or no attributes
|
||||
if let Some(item) = content.single_item {
|
||||
return item;
|
||||
} else if content.attributes.is_empty() {
|
||||
return top_level_type.to_string();
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
|
||||
// Print the struct name
|
||||
if settings.new_line_after_attribute() {
|
||||
writeln!(result, "{} {{", top_level_type).unwrap();
|
||||
} else {
|
||||
write!(result, "{} {{", top_level_type).unwrap();
|
||||
}
|
||||
|
||||
for (i, attribute) in content.attributes.iter().enumerate() {
|
||||
if settings.new_line_after_attribute() {
|
||||
writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap();
|
||||
} else if i == 0 {
|
||||
write!(result, "{}: {}", attribute.name, attribute.value).unwrap();
|
||||
} else {
|
||||
write!(result, ", {}: {}", attribute.name, attribute.value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
if settings.show_num_parameters() {
|
||||
let num_params = self.num_params();
|
||||
if num_params > 0 {
|
||||
if settings.new_line_after_attribute() {
|
||||
writeln!(result, "{indent}params: {}", num_params).unwrap();
|
||||
} else {
|
||||
write!(result, ", params: {}", num_params).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if settings.new_line_after_attribute() {
|
||||
write!(result, "{indent_close_braces}}}").unwrap();
|
||||
} else {
|
||||
write!(result, "}}").unwrap();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Custom display settings for the module.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional display settings object.
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Custom attributes for the module.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `_content` - The content object that contains display settings and attributes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional content object containing the custom attributes.
|
||||
fn custom_content(&self, _content: Content) -> Option<Content> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom module display settings.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DisplaySettings {
|
||||
/// Whether to print the module parameter ids.
|
||||
show_param_id: Option<bool>,
|
||||
|
||||
/// Whether to print the module attributes.
|
||||
show_all_attributes: Option<bool>,
|
||||
|
||||
/// Whether to print the module number of parameters.
|
||||
show_num_parameters: Option<bool>,
|
||||
|
||||
/// Print new line after an attribute.
|
||||
new_line_after_attribute: Option<bool>,
|
||||
|
||||
/// Indentation size.
|
||||
indentation_size: Option<usize>,
|
||||
|
||||
/// Level of indentation.
|
||||
level: usize,
|
||||
}
|
||||
|
||||
impl Default for DisplaySettings {
|
||||
fn default() -> Self {
|
||||
DisplaySettings {
|
||||
show_param_id: None,
|
||||
show_all_attributes: None,
|
||||
show_num_parameters: None,
|
||||
new_line_after_attribute: None,
|
||||
indentation_size: None,
|
||||
level: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DisplaySettings {
|
||||
/// Create a new format settings.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new instance of `DisplaySettings`.
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
/// Sets a flag to show module parameters.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `flag` - Boolean flag to show module parameters.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_show_param_id(mut self, flag: bool) -> Self {
|
||||
self.show_param_id = Some(flag);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a flag to show module attributes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `flag` - Boolean flag to show all module attributes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_show_all_attributes(mut self, flag: bool) -> Self {
|
||||
self.show_all_attributes = Some(flag);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a flag to show the number of module parameters.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `flag` - Boolean flag to show the number of module parameters.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_show_num_parameters(mut self, flag: bool) -> Self {
|
||||
self.show_num_parameters = Some(flag);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a flag to print a new line after an attribute.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `flag` - Boolean flag to print a new line after an attribute.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self {
|
||||
self.new_line_after_attribute = Some(flag);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the indentation size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `size` - The size of the indentation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn with_indentation_size(mut self, size: usize) -> Self {
|
||||
self.indentation_size = Some(size);
|
||||
self
|
||||
}
|
||||
|
||||
/// Inherits settings from the provided settings and return a new settings object.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `top` - The top level `DisplaySettings` to inherit from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance.
|
||||
pub fn inherit(self, top: Self) -> Self {
|
||||
let mut updated = self.clone();
|
||||
|
||||
if let Some(show_param_id) = top.show_param_id {
|
||||
updated.show_param_id = Some(show_param_id);
|
||||
};
|
||||
|
||||
if let Some(show_all_attributes) = top.show_all_attributes {
|
||||
updated.show_all_attributes = Some(show_all_attributes);
|
||||
}
|
||||
|
||||
if let Some(show_num_parameters) = top.show_num_parameters {
|
||||
updated.show_num_parameters = Some(show_num_parameters);
|
||||
}
|
||||
|
||||
if let Some(new_line_after_attribute) = top.new_line_after_attribute {
|
||||
updated.new_line_after_attribute = Some(new_line_after_attribute);
|
||||
}
|
||||
|
||||
if let Some(indentation_size) = top.indentation_size {
|
||||
updated.indentation_size = Some(indentation_size);
|
||||
}
|
||||
|
||||
updated.level = top.level;
|
||||
|
||||
updated
|
||||
}
|
||||
|
||||
/// A convenience method to wrap the DisplaySettings struct in an option.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional `DisplaySettings`.
|
||||
pub fn optional(self) -> Option<Self> {
|
||||
Some(self)
|
||||
}
|
||||
|
||||
/// Increases the level of indentation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `DisplaySettings` instance with increased indentation level.
|
||||
pub fn level_up(mut self) -> Self {
|
||||
self.level += 1;
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets `show_param_id` flag, substitutes false if not set.
|
||||
///
|
||||
/// This flag is used to print the module parameter ids.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean value indicating whether to show parameter ids.
|
||||
pub fn show_param_id(&self) -> bool {
|
||||
self.show_param_id.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Gets `show_all_attributes`, substitutes false if not set.
|
||||
///
|
||||
/// This flag is used to force to print all module attributes, overriding custom attributes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean value indicating whether to show all attributes.
|
||||
pub fn show_all_attributes(&self) -> bool {
|
||||
self.show_all_attributes.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Gets `show_num_parameters`, substitutes true if not set.
|
||||
///
|
||||
/// This flag is used to print the number of module parameters.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean value indicating whether to show the number of parameters.
|
||||
pub fn show_num_parameters(&self) -> bool {
|
||||
self.show_num_parameters.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Gets `new_line_after_attribute`, substitutes true if not set.
|
||||
///
|
||||
/// This flag is used to print a new line after an attribute.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean value indicating whether to print a new line after an attribute.
|
||||
pub fn new_line_after_attribute(&self) -> bool {
|
||||
self.new_line_after_attribute.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Gets `indentation_size`, substitutes 2 if not set.
|
||||
///
|
||||
/// This flag is used to set the size of indentation.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An integer value indicating the size of indentation.
|
||||
pub fn indentation_size(&self) -> usize {
|
||||
self.indentation_size.unwrap_or(2)
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct to store the attributes of a module for formatting.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Content {
|
||||
/// List of attributes.
|
||||
pub attributes: Vec<Attribute>,
|
||||
|
||||
/// Single item content.
|
||||
pub single_item: Option<String>,
|
||||
|
||||
/// Display settings.
|
||||
pub display_settings: DisplaySettings,
|
||||
|
||||
/// Top level type name.
|
||||
pub top_level_type: Option<String>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
/// Creates a new attributes struct.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `display_settings` - Display settings for the content.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new instance of `Content`.
|
||||
pub fn new(display_settings: DisplaySettings) -> Self {
|
||||
Content {
|
||||
attributes: Vec::new(),
|
||||
single_item: None,
|
||||
display_settings,
|
||||
top_level_type: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds an attribute to the format settings. The value will be formatted and stored as a string.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `name` - Name of the attribute.
|
||||
/// * `value` - Value of the attribute.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `Content` instance with the new attribute added.
|
||||
pub fn add<T: ModuleDisplay + ?Sized>(mut self, name: &str, value: &T) -> Self {
|
||||
if self.single_item.is_some() {
|
||||
panic!("Cannot add multiple attributes when single item is set.");
|
||||
}
|
||||
|
||||
let attribute = Attribute {
|
||||
name: name.to_owned(),
|
||||
value: value.format(self.display_settings.clone()), // TODO level + 1
|
||||
ty: any::type_name::<T>().to_string(),
|
||||
};
|
||||
self.attributes.push(attribute);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a single item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `value` - Rendered string of the single item.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `Content` instance with the single item added.
|
||||
pub fn add_single<T: ModuleDisplay + ?Sized>(mut self, value: &T) -> Self {
|
||||
if !self.attributes.is_empty() {
|
||||
panic!("Cannot add single item when attributes are set.");
|
||||
}
|
||||
|
||||
self.single_item = Some(value.format(self.display_settings.clone()));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a single item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `value` - Formatted display value.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `Content` instance with the formatted single item added.
|
||||
pub fn add_formatted<T: Display>(mut self, value: &T) -> Self {
|
||||
if !self.attributes.is_empty() {
|
||||
panic!("Cannot add single item when attributes are set.");
|
||||
}
|
||||
|
||||
self.single_item = Some(format!("{}", value));
|
||||
self
|
||||
}
|
||||
|
||||
/// A convenience method to wrap the Attributes struct in an option
|
||||
/// because it is often used as an optional field.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional `Content`.
|
||||
pub fn optional(self) -> Option<Self> {
|
||||
if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none()
|
||||
{
|
||||
None
|
||||
} else {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the top level type name.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `ty` - The type name to set.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Updated `Content` instance with the top level type name set.
|
||||
pub fn set_top_level_type(mut self, ty: &str) -> Self {
|
||||
self.top_level_type = Some(ty.to_owned());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Attribute to print in the display method.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Attribute {
|
||||
/// Name of the attribute.
|
||||
pub name: String,
|
||||
|
||||
/// Value of the attribute.
|
||||
pub value: String,
|
||||
|
||||
/// Type of the attribute.
|
||||
pub ty: String,
|
||||
}
|
||||
|
||||
/// Extracts the short name of a type T
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A string slice representing the short name of the type.
|
||||
pub fn extract_type_name<T: ?Sized>() -> &'static str {
|
||||
// Get the full type name of T, including module path and generic parameters
|
||||
let ty = any::type_name::<T>();
|
||||
|
||||
// Find the first occurrence of '<' in the full type name
|
||||
// If not found, use the length of the type name
|
||||
let end = ty.find('<').unwrap_or(ty.len());
|
||||
|
||||
// Slice the type name up to the first '<' or the end
|
||||
let ty = &ty[0..end];
|
||||
|
||||
// Find the last occurrence of "::" in the sliced type name
|
||||
// If found, add 2 to skip the "::" itself
|
||||
// If not found, start from the beginning of the type name
|
||||
let start = ty.rfind("::").map(|i| i + 2).unwrap_or(0);
|
||||
|
||||
// Find the last occurrence of '<' in the sliced type name
|
||||
// If not found, use the length of the type name
|
||||
let end = ty.rfind('<').unwrap_or(ty.len());
|
||||
|
||||
// If the start index is less than the end index,
|
||||
// return the slice of the type name from start to end
|
||||
// Otherwise, return the entire sliced type name
|
||||
if start < end {
|
||||
&ty[start..end]
|
||||
} else {
|
||||
ty
|
||||
}
|
||||
}
|
|
@ -1,5 +1,7 @@
|
|||
mod base;
|
||||
mod display;
|
||||
mod param;
|
||||
|
||||
pub use base::*;
|
||||
pub use display::*;
|
||||
pub use param::*;
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
use alloc::{format, string::ToString};
|
||||
use core::{fmt::Display, marker::PhantomData};
|
||||
|
||||
use crate::{
|
||||
self as burn,
|
||||
module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
|
||||
module::{
|
||||
AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
|
||||
ModuleMapper, ModuleVisitor,
|
||||
},
|
||||
record::Record,
|
||||
};
|
||||
use burn::record::PrecisionSettings;
|
||||
|
@ -8,7 +14,6 @@ use burn_tensor::{
|
|||
backend::{AutodiffBackend, Backend},
|
||||
BasicAutodiffOps, BasicOps, Tensor,
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Record used for constant type implementing the [module](crate::module::Module) trait.
|
||||
#[derive(Debug, Clone, Copy, new, Default)]
|
||||
|
@ -96,6 +101,15 @@ macro_rules! constant {
|
|||
impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {
|
||||
constant!(ad_module, $type);
|
||||
}
|
||||
|
||||
impl burn::module::ModuleDisplayDefault for $type {
|
||||
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
|
||||
let string = format!("{}", self);
|
||||
content.add_formatted(&string).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl burn::module::ModuleDisplay for $type {}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -122,6 +136,13 @@ constant!(i32);
|
|||
constant!(i16);
|
||||
constant!(i8);
|
||||
|
||||
impl burn::module::ModuleDisplay for str {}
|
||||
impl burn::module::ModuleDisplayDefault for str {
|
||||
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
|
||||
content.add_formatted(&self).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
|
||||
type Record = ConstantRecord;
|
||||
|
||||
|
@ -158,6 +179,15 @@ impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
|
||||
content.add_single(&string).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
|
||||
for Tensor<B, D, K>
|
||||
{
|
||||
|
@ -200,6 +230,14 @@ impl<B: Backend> Module<B> for PhantomData<B> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
content.add_single(&"PhantomData".to_string()).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PhantomData<B> {}
|
||||
|
||||
impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
|
||||
type InnerModule = PhantomData<B::InnerBackend>;
|
||||
|
||||
|
@ -248,6 +286,27 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> ModuleDisplayDefault for Ignored<T>
|
||||
where
|
||||
T: Sync + Send + core::fmt::Debug + Clone,
|
||||
{
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
// For now, just print the debug representation of the ignored value
|
||||
content.add_single(&format!("{:?}", self.0)).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
|
||||
|
||||
impl<T> Display for Ignored<T>
|
||||
where
|
||||
T: Sync + Send + core::fmt::Debug + Clone,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
write!(f, "{:?}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
|
||||
use alloc::vec::Vec;
|
||||
use crate::module::{
|
||||
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
|
||||
ModuleVisitor,
|
||||
};
|
||||
|
||||
use alloc::{format, vec::Vec};
|
||||
|
||||
use burn_tensor::backend::{AutodiffBackend, Backend};
|
||||
use core::fmt::Debug;
|
||||
|
||||
|
@ -52,6 +57,17 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
match self {
|
||||
Some(module) => content.add_single(module).optional(),
|
||||
None => content.add_single("None").optional(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
|
||||
|
||||
impl<T, B> AutodiffModule<B> for Option<T>
|
||||
where
|
||||
T: AutodiffModule<B> + Debug + Send + Clone,
|
||||
|
@ -128,6 +144,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
self.iter()
|
||||
.enumerate()
|
||||
.fold(content, |acc, (i, module)| {
|
||||
let index = format!("{}", i);
|
||||
acc.add(&index, module)
|
||||
})
|
||||
.set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
|
||||
|
||||
impl<T, B> AutodiffModule<B> for Vec<T>
|
||||
where
|
||||
T: AutodiffModule<B> + Debug + Send + Clone,
|
||||
|
@ -197,6 +228,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
self.iter()
|
||||
.enumerate()
|
||||
.fold(content, |acc, (i, module)| {
|
||||
let index = format!("{}", i);
|
||||
acc.add(&index, module)
|
||||
})
|
||||
.set_top_level_type(format!("[0..{}]", self.len()).as_str())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
|
||||
|
||||
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
|
||||
where
|
||||
T: AutodiffModule<B> + Debug + Send + Clone + Copy,
|
||||
|
@ -269,6 +315,21 @@ macro_rules! impl_module_tuple {
|
|||
($(self.$i.valid(),)*)
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
|
||||
where
|
||||
$($l: ModuleDisplay,)*
|
||||
{
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let content = content
|
||||
$(.add(&format!("{}", $i), &self.$i))*
|
||||
.set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
|
||||
content.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,13 @@
|
|||
use super::ParamId;
|
||||
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
|
||||
use crate::module::{
|
||||
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
|
||||
ModuleVisitor, Param,
|
||||
};
|
||||
|
||||
use alloc::string::ToString;
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use burn_common::stub::Mutex;
|
||||
use burn_tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
|
@ -45,6 +51,24 @@ pub struct RunningState<V> {
|
|||
value: Arc<Mutex<V>>,
|
||||
}
|
||||
|
||||
// Implement display for the module
|
||||
|
||||
impl<V> core::fmt::Display for RunningState<V> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "RunningState(id={})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V> ModuleDisplayDefault for RunningState<V> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add_formatted(&"RunningState".to_string())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<V> ModuleDisplay for RunningState<V> {}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
|
||||
type Record = Param<Tensor<B, D>>;
|
||||
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
use super::{Param, ParamId, Parameter};
|
||||
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
|
||||
use crate::module::{
|
||||
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
|
||||
ModuleVisitor,
|
||||
};
|
||||
use crate::tensor::{
|
||||
backend::{AutodiffBackend, Backend},
|
||||
Tensor,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{format, string::ToString, vec::Vec};
|
||||
use burn_tensor::{Bool, Data, Float, Int};
|
||||
|
||||
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
|
||||
|
@ -147,6 +150,22 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let id = if content.display_settings.show_param_id() {
|
||||
format!(", id: {}", self.id)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
let string = format!(
|
||||
"ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
|
||||
self.shape().dims
|
||||
);
|
||||
content.add_formatted(&string).optional()
|
||||
}
|
||||
}
|
||||
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
|
||||
type Record = Param<Tensor<B, D, Int>>;
|
||||
|
||||
|
@ -198,6 +217,22 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let id = if content.display_settings.show_param_id() {
|
||||
format!(", id: {}", self.id)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
let string = format!(
|
||||
"ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
|
||||
self.shape().dims
|
||||
);
|
||||
content.add_formatted(&string).optional()
|
||||
}
|
||||
}
|
||||
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
|
||||
type Record = Param<Tensor<B, D, Bool>>;
|
||||
|
||||
|
@ -249,6 +284,24 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
|
||||
fn content(&self, content: Content) -> Option<Content> {
|
||||
let id = if content.display_settings.show_param_id() {
|
||||
format!(", id: {}", self.id)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let string = format!(
|
||||
"ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
|
||||
self.shape().dims
|
||||
);
|
||||
content.add_formatted(&string).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
|
||||
|
||||
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
|
||||
type InnerModule = Param<Tensor<B::InnerBackend, D>>;
|
||||
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
use alloc::format;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::nn::conv::checks;
|
||||
use crate::nn::{Initializer, PaddingConfig1d};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::module::conv1d;
|
||||
use crate::tensor::ops::ConvOptions;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::{
|
||||
config::Config,
|
||||
module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param},
|
||||
nn::{conv::checks, Initializer, PaddingConfig1d},
|
||||
tensor::{backend::Backend, module::conv1d, ops::ConvOptions, Tensor},
|
||||
};
|
||||
|
||||
/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
|
@ -45,6 +44,7 @@ pub struct Conv1dConfig {
|
|||
///
|
||||
/// Should be created with [Conv1dConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Conv1d<B: Backend> {
|
||||
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
|
||||
pub weight: Param<Tensor<B, 3>>,
|
||||
|
@ -54,7 +54,28 @@ pub struct Conv1d<B: Backend> {
|
|||
kernel_size: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
padding: PaddingConfig1d,
|
||||
padding: Ignored<PaddingConfig1d>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Conv1d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
// Since padding does not implement ModuleDisplay, we need to format it manually.
|
||||
let padding_formatted = format!("{}", &self.padding);
|
||||
|
||||
content
|
||||
.add("stride", &self.stride)
|
||||
.add("kernel_size", &self.kernel_size)
|
||||
.add("dilation", &self.dilation)
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &padding_formatted)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl Conv1dConfig {
|
||||
|
@ -87,7 +108,7 @@ impl Conv1dConfig {
|
|||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
padding: Ignored(self.padding.clone()),
|
||||
dilation: self.dilation,
|
||||
groups: self.groups,
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use alloc::format;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
|
||||
use crate::nn::Initializer;
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -45,6 +46,7 @@ pub struct Conv2dConfig {
|
|||
///
|
||||
/// Should be created with [Conv2dConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Conv2d<B: Backend> {
|
||||
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
|
||||
pub weight: Param<Tensor<B, 4>>,
|
||||
|
@ -54,7 +56,7 @@ pub struct Conv2d<B: Backend> {
|
|||
kernel_size: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
groups: usize,
|
||||
padding: PaddingConfig2d,
|
||||
padding: Ignored<PaddingConfig2d>,
|
||||
}
|
||||
|
||||
impl Conv2dConfig {
|
||||
|
@ -93,12 +95,38 @@ impl Conv2dConfig {
|
|||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
padding: self.padding.clone(),
|
||||
padding: Ignored(self.padding.clone()),
|
||||
groups: self.groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Conv2d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
// Since padding does not implement ModuleDisplay, we need to format it manually.
|
||||
let padding_formatted = format!("{}", &self.padding);
|
||||
|
||||
// Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
|
||||
let stride = format!("{:?}", self.stride);
|
||||
let kernel_size = format!("{:?}", self.kernel_size);
|
||||
let dilation = format!("{:?}", self.dilation);
|
||||
|
||||
content
|
||||
.add("stride", &stride)
|
||||
.add("kernel_size", &kernel_size)
|
||||
.add("dilation", &dilation)
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &padding_formatted)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Conv2d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::{Distribution, Tensor};
|
||||
|
||||
|
@ -21,6 +21,7 @@ pub struct DropoutConfig {
|
|||
///
|
||||
/// Should be created with [DropoutConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Dropout {
|
||||
prob: f64,
|
||||
}
|
||||
|
@ -54,6 +55,18 @@ impl Dropout {
|
|||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for Dropout {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
|
||||
content.add("prob", &self.prob).optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use crate as burn;
|
||||
use crate::module::DisplaySettings;
|
||||
use crate::module::ModuleDisplay;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
|
@ -30,6 +32,7 @@ pub struct LinearConfig {
|
|||
///
|
||||
/// `O = IW + b`
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Linear<B: Backend> {
|
||||
/// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
|
||||
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
|
||||
|
@ -83,6 +86,23 @@ impl<B: Backend> Linear<B> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Linear<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
|
||||
let [d_input, d_output] = self.weight.shape().dims;
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_output", &d_output)
|
||||
.add("bias", &self.bias.is_some())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate as burn;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
|
||||
use crate::nn::Initializer;
|
||||
use crate::{
|
||||
|
@ -33,6 +34,7 @@ pub struct BatchNormConfig {
|
|||
///
|
||||
/// Should be created using [BatchNormConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct BatchNorm<B: Backend, const D: usize> {
|
||||
/// The learnable weight gamma.
|
||||
pub gamma: Param<Tensor<B, 1>>,
|
||||
|
@ -183,6 +185,24 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> ModuleDisplay for BatchNorm<B, D> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [num_features] = self.beta.shape().dims;
|
||||
|
||||
content
|
||||
.add("num_features", &num_features)
|
||||
.add("momentum", &self.momentum)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[cfg(test)]
|
||||
mod tests_1d {
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::DisplaySettings;
|
||||
use crate::module::Module;
|
||||
use crate::module::ModuleDisplay;
|
||||
use crate::module::Param;
|
||||
use crate::nn::Initializer;
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -29,6 +30,7 @@ pub struct LayerNormConfig {
|
|||
///
|
||||
/// Should be created using [LayerNormConfig](LayerNormConfig).
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct LayerNorm<B: Backend> {
|
||||
/// The learnable weight.
|
||||
gamma: Param<Tensor<B, 1>>,
|
||||
|
@ -71,6 +73,22 @@ impl<B: Backend> LayerNorm<B> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for LayerNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
|
||||
let [d_model] = self.gamma.shape().dims;
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -3,10 +3,9 @@ use crate as burn;
|
|||
use crate::tensor::ops::conv::calculate_conv_padding;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
|
||||
/// Padding configuration for 1D operators.
|
||||
#[derive(Module, Config, Debug, PartialEq)]
|
||||
#[derive(Config, Debug, PartialEq)]
|
||||
pub enum PaddingConfig1d {
|
||||
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
|
||||
/// the same as the input.
|
||||
|
@ -34,7 +33,7 @@ impl PaddingConfig1d {
|
|||
}
|
||||
|
||||
/// Padding configuration for 2D operators.
|
||||
#[derive(Module, Config, Debug, PartialEq)]
|
||||
#[derive(Config, Debug, PartialEq)]
|
||||
pub enum PaddingConfig2d {
|
||||
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
|
||||
/// the same as the input.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::nn::PaddingConfig1d;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
@ -43,7 +43,7 @@ pub struct AvgPool1dConfig {
|
|||
pub struct AvgPool1d {
|
||||
stride: usize,
|
||||
kernel_size: usize,
|
||||
padding: PaddingConfig1d,
|
||||
padding: Ignored<PaddingConfig1d>,
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,7 @@ impl AvgPool1dConfig {
|
|||
AvgPool1d {
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
padding: Ignored(self.padding.clone()),
|
||||
count_include_pad: self.count_include_pad,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
@ -42,7 +42,7 @@ pub struct AvgPool2dConfig {
|
|||
pub struct AvgPool2d {
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
padding: PaddingConfig2d,
|
||||
padding: Ignored<PaddingConfig2d>,
|
||||
count_include_pad: bool,
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ impl AvgPool2dConfig {
|
|||
AvgPool2d {
|
||||
stride: self.strides,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
padding: Ignored(self.padding.clone()),
|
||||
count_include_pad: self.count_include_pad,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::nn::PaddingConfig1d;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
@ -31,7 +31,7 @@ pub struct MaxPool1dConfig {
|
|||
pub struct MaxPool1d {
|
||||
stride: usize,
|
||||
kernel_size: usize,
|
||||
padding: PaddingConfig1d,
|
||||
padding: Ignored<PaddingConfig1d>,
|
||||
dilation: usize,
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,7 @@ impl MaxPool1dConfig {
|
|||
MaxPool1d {
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
padding: Ignored(self.padding.clone()),
|
||||
dilation: self.dilation,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
@ -31,7 +31,7 @@ pub struct MaxPool2dConfig {
|
|||
pub struct MaxPool2d {
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
padding: PaddingConfig2d,
|
||||
padding: Ignored<PaddingConfig2d>,
|
||||
dilation: [usize; 2],
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,7 @@ impl MaxPool2dConfig {
|
|||
MaxPool2d {
|
||||
stride: self.strides,
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
padding: Ignored(self.padding.clone()),
|
||||
dilation: self.dilation,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::ops::UnfoldOptions;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
use crate::tensor::module::unfold4d;
|
||||
use crate::module::{Ignored, Module};
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::module::unfold4d;
|
||||
use burn_tensor::ops::UnfoldOptions;
|
||||
use burn_tensor::Tensor;
|
||||
|
||||
/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
|
@ -29,14 +28,14 @@ pub struct Unfold4dConfig {
|
|||
/// Should be created with [Unfold4dConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct Unfold4d {
|
||||
config: Unfold4dConfig,
|
||||
config: Ignored<Unfold4dConfig>,
|
||||
}
|
||||
|
||||
impl Unfold4dConfig {
|
||||
/// Initializes a new [Unfold4d] module.
|
||||
pub fn init(&self) -> Unfold4d {
|
||||
Unfold4d {
|
||||
config: self.clone(),
|
||||
config: Ignored(self.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -48,7 +47,7 @@ impl Unfold4d {
|
|||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// input: `[batch_size, channels_in, height, width]`
|
||||
/// input: `[batch_size, channels_in, height, width]`
|
||||
/// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
|
||||
unfold4d(
|
||||
|
|
|
@ -370,6 +370,6 @@ mod tests {
|
|||
|
||||
// Compare the lengths of expected and actual serialized strings because
|
||||
// the order of the fields is not guaranteed for HashMaps.
|
||||
assert_eq!(serialized_str.len(), 134);
|
||||
assert_eq!(serialized_str.len(), 108);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ pub(crate) mod record;
|
|||
pub(crate) mod shared;
|
||||
|
||||
/// Derive macro for the module.
|
||||
#[proc_macro_derive(Module)]
|
||||
#[proc_macro_derive(Module, attributes(module))]
|
||||
pub fn module_derive(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
module::derive_impl(&input)
|
||||
|
|
|
@ -2,7 +2,7 @@ use super::{display, record::ModuleRecordCodegen};
|
|||
use crate::shared::generics::GenericsHelper;
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{parse_quote, Generics};
|
||||
use syn::{parse_quote, Attribute, Generics};
|
||||
|
||||
/// Basic trait to be implemented for Module generation.
|
||||
pub(crate) trait ModuleCodegen {
|
||||
|
@ -30,8 +30,8 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
|
|||
|
||||
let generics = GenericsParser::from_ast(&ast.generics);
|
||||
|
||||
let display_fn = display::display_fn(name);
|
||||
|
||||
let display_fn = display::display_fn(ast);
|
||||
let attributes_fn = display::attributes_fn(ast);
|
||||
let num_params_fn = codegen.gen_num_params();
|
||||
let visit = codegen.gen_visit();
|
||||
let map_mut = codegen.gen_map();
|
||||
|
@ -54,7 +54,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
|
|||
|
||||
let generics_ty_inner_module = generics.inner_module_ty;
|
||||
|
||||
let gen = quote! {
|
||||
let mut gen = quote! {
|
||||
impl #generics_module burn::module::Module<B> for #name #generics_ty_module #generics_where_module {
|
||||
type Record = #record_name #generics_ty_module;
|
||||
|
||||
|
@ -69,6 +69,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
|
|||
#collect_devices
|
||||
#to_device
|
||||
#fork
|
||||
|
||||
}
|
||||
|
||||
impl #generics_module_autodiff burn::module::AutodiffModule<B> for #name #generics_ty_module_autodiff #generics_where_module_autodiff
|
||||
|
@ -82,6 +83,15 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
|
|||
#display_fn
|
||||
}
|
||||
|
||||
|
||||
impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module {
|
||||
#attributes_fn
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
burn::module::Module::num_params(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl #generics_module Clone for #name #generics_ty_module #generics_where_module {
|
||||
#clone_fn
|
||||
}
|
||||
|
@ -89,13 +99,21 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
|
|||
#record_type
|
||||
};
|
||||
|
||||
if !has_custom_display(&ast.attrs) {
|
||||
gen.extend(quote! {
|
||||
impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module {
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
gen
|
||||
}
|
||||
|
||||
// When there is no backend in the generic parameter, the type is considered as a constant.
|
||||
pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
|
||||
let name = &ast.ident;
|
||||
let (_generics, generics_ty, generics_where) = ast.generics.split_for_impl();
|
||||
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
|
||||
|
||||
let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
|
||||
let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::AutodiffBackend >};
|
||||
|
@ -112,7 +130,10 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
|
|||
let (generics_module, _, _) = generics_module.split_for_impl();
|
||||
let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl();
|
||||
|
||||
let gen = quote! {
|
||||
let display_fn = display::display_fn(ast);
|
||||
let attributes_fn = display::attributes_fn(ast);
|
||||
|
||||
let mut gen = quote! {
|
||||
impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
|
||||
burn::constant!(module);
|
||||
}
|
||||
|
@ -121,8 +142,26 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
|
|||
for #name #generics_ty #generics_where {
|
||||
burn::constant!(ad_module, #name #generics_ty);
|
||||
}
|
||||
|
||||
impl #generics core::fmt::Display for #name #generics_ty #generics_where {
|
||||
#display_fn
|
||||
}
|
||||
|
||||
|
||||
impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where {
|
||||
#attributes_fn
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
if !has_custom_display(&ast.attrs) {
|
||||
gen.extend(quote! {
|
||||
impl #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where {
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
gen
|
||||
}
|
||||
|
||||
|
@ -159,22 +198,64 @@ impl GenericsParser {
|
|||
#ident: burn::module::Module<B>
|
||||
}
|
||||
);
|
||||
|
||||
module.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::ModuleDisplayDefault
|
||||
}
|
||||
);
|
||||
|
||||
module.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::AutodiffModule<B>
|
||||
}
|
||||
);
|
||||
module_autodiff.add_predicate(
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::Module<B::InnerBackend>
|
||||
}
|
||||
);
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule<B>>::InnerModule, });
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::Module<B::InnerBackend>
|
||||
}
|
||||
);
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::ModuleDisplayDefault
|
||||
}
|
||||
);
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
|
||||
});
|
||||
|
||||
module.consts().into_iter().for_each(|ident| {
|
||||
|
@ -188,3 +269,18 @@ impl GenericsParser {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn has_custom_display(attrs: &[Attribute]) -> bool {
|
||||
attrs.iter().any(|attr| {
|
||||
attr.path().is_ident("module")
|
||||
&& attr
|
||||
.parse_nested_meta(|meta| {
|
||||
if meta.path.is_ident("custom_display") {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(meta.error("unsupported attribute"))
|
||||
}
|
||||
})
|
||||
.is_ok()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,11 +1,96 @@
|
|||
use proc_macro2::Ident;
|
||||
use quote::quote;
|
||||
|
||||
pub fn display_fn(name: &Ident) -> proc_macro2::TokenStream {
|
||||
pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
|
||||
match &ast.data {
|
||||
syn::Data::Struct(ref data_struct) => {
|
||||
let fields = match &data_struct.fields {
|
||||
syn::Fields::Named(ref named_fields) => {
|
||||
named_fields.named.iter().collect::<Vec<_>>()
|
||||
}
|
||||
syn::Fields::Unit => Vec::new(),
|
||||
_ => panic!("attributes_fn only supports structs with named or unit fields"),
|
||||
};
|
||||
let field_prints = fields.iter().map(|field| {
|
||||
let field_name = &field.ident;
|
||||
quote! { .add(stringify!(#field_name), &self.#field_name) }
|
||||
});
|
||||
let struct_name = &ast.ident;
|
||||
quote! {
|
||||
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
|
||||
content
|
||||
.set_top_level_type(&stringify!(#struct_name))
|
||||
#(#field_prints)*
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Data::Enum(ref data_enum) => {
|
||||
let variant_prints = data_enum.variants.iter().map(|variant| {
|
||||
let variant_name = &variant.ident;
|
||||
match &variant.fields {
|
||||
syn::Fields::Unit => {
|
||||
quote! {
|
||||
Self::#variant_name => {
|
||||
content.add_formatted(&stringify!(#variant_name).to_string())
|
||||
.optional()
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Fields::Named(ref named_fields) => {
|
||||
let field_prints = named_fields.named.iter().map(|field| {
|
||||
let field_name = &field.ident;
|
||||
quote! { .add(stringify!(#field_name), &self.#field_name) }
|
||||
});
|
||||
|
||||
let field_names = named_fields.named.iter().map(|field| {
|
||||
let field_name = &field.ident;
|
||||
quote! { #field_name }
|
||||
});
|
||||
|
||||
quote! {
|
||||
Self::#variant_name { #(#field_names),* } => {
|
||||
content.set_top_level_type(&stringify!(#variant_name))
|
||||
#(#field_prints)*
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Fields::Unnamed(ref unnamed_fields) => {
|
||||
let field_names = (0..unnamed_fields.unnamed.len()).map(|i| {
|
||||
syn::Ident::new(&format!("_{}", i), proc_macro2::Span::call_site())
|
||||
});
|
||||
|
||||
let field_prints = field_names.clone().map(|field_name| {
|
||||
quote! { .add(stringify!(#field_name), #field_name) }
|
||||
});
|
||||
quote! {
|
||||
Self::#variant_name(#(#field_names),*) => {
|
||||
content.set_top_level_type(&stringify!(#variant_name))
|
||||
#(#field_prints)*
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
quote! {
|
||||
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
|
||||
match self {
|
||||
#(#variant_prints)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => panic!("attributes_fn only supports structs and enums"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
write!(f, "{}[num_params={}]", stringify!(#name), self.num_params())
|
||||
|
||||
let formatted = burn::module::ModuleDisplay::format(self, Default::default());
|
||||
write!(f, "{}", formatted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -151,7 +151,7 @@ impl Display for LearnerSummary {
|
|||
)?;
|
||||
|
||||
if let Some(model) = &self.model {
|
||||
writeln!(f, "Model: {model}")?;
|
||||
writeln!(f, "Model:\n{model}")?;
|
||||
}
|
||||
writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;
|
||||
|
||||
|
|
Loading…
Reference in New Issue