Print model structure like with PyTorch - Part 1 (#1912)

This commit is contained in:
Dilshod Tadjibaev 2024-06-25 08:23:10 -05:00 committed by GitHub
parent 3faf544bc4
commit 2c51615471
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 1135 additions and 78 deletions

8
Cargo.lock generated
View File

@ -463,6 +463,7 @@ version = "0.14.0"
dependencies = [
"async-trait",
"dashmap",
"data-encoding",
"derive-new",
"getrandom",
"indicatif",
@ -471,7 +472,6 @@ dependencies = [
"serde",
"spin",
"tokio",
"uuid",
"web-time",
]
@ -1469,6 +1469,12 @@ dependencies = [
"parking_lot_core 0.9.10",
]
[[package]]
name = "data-encoding"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
[[package]]
name = "deflate64"
version = "0.1.8"

View File

@ -34,6 +34,9 @@ colored = "2.1.0"
console_error_panic_hook = "0.1.7"
csv = "1.3.0"
dashmap = "5.5.3"
data-encoding = { version = "2.6.0", default-features = false, features = [
"alloc",
] }
dirs = "5.0.1"
fake = "2.9.2"
flate2 = "1.0.30"
@ -42,16 +45,19 @@ getrandom = { version = "0.2.15", default-features = false }
gix-tempfile = { version = "13.1.1", features = ["signals"] }
globwalk = "0.9.1"
hashbrown = "0.14.5"
hound = "3.5.1"
image = "0.25.1"
indicatif = "0.17.8"
js-sys = "0.3.69"
libm = "0.2.8"
log = { default-features = false, version = "0.4.21" }
md5 = "0.7.0"
percent-encoding = "2.3.1"
pretty_assertions = "1.4.0"
proc-macro2 = "1.0.85"
protobuf = "3.4.0"
protobuf-codegen = "3.4.0"
quote = "1.0.36"
percent-encoding = "2.3.1"
r2d2 = "0.8.10"
r2d2_sqlite = { version = "0.24.0" }
rayon = "1.10.0"
@ -63,6 +69,7 @@ rusqlite = { version = "0.31.0" }
rust-format = { version = "0.3.4" }
sanitize-filename = "0.5.0"
serde_rusqlite = "0.35.0"
serial_test = "3.1.1"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
strum = "0.26.2"
strum_macros = "0.26.4"
@ -73,11 +80,7 @@ tokio = { version = "1.38.0", features = ["rt", "macros"] }
tracing-appender = "0.2.3"
tracing-core = "0.1.32"
tracing-subscriber = "0.3.18"
md5 = "0.7.0"
serial_test = "3.1.1"
web-time = "1.1.0"
hound = "3.5.1"
image = "0.25.1"
zip = "2.1.3"
# Terminal UI

View File

@ -12,7 +12,7 @@ version.workspace = true
[features]
default = ["std"]
std = ["rand/std"]
std = ["rand/std", "data-encoding/std"]
doc = ["default"]
wasm-sync = []
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
@ -27,10 +27,10 @@ web-time = { version = "1.1.0" }
# ** Please make sure all dependencies support no_std when std is disabled **
rand = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
uuid = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
derive-new = { workspace = true }
serde = { workspace = true }
data-encoding = { workspace = true }
# Network downloader
indicatif = { workspace = true, optional = true }

View File

@ -1,18 +1,21 @@
use alloc::string::String;
use crate::rand::gen_random;
use alloc::string::{String, ToString};
use uuid::{Builder, Bytes};
use data_encoding::BASE32_DNSSEC;
/// Simple ID generator.
pub struct IdGenerator {}
impl IdGenerator {
/// Generates a new ID in the form of a UUID.
/// Generates a new ID.
pub fn generate() -> String {
let random_bytes: Bytes = gen_random();
// Generate 6 random bytes (281,474,976,710,656 combinations)
let random_bytes: [u8; 6] = gen_random();
let uuid = Builder::from_random_bytes(random_bytes).into_uuid();
uuid.as_hyphenated().to_string()
// Encode the random bytes in base32 DNSSEC
// 6 bytes encodes to 10 lower case characters, e.g. "3uu5e6vv7c"
BASE32_DNSSEC.encode(&random_bytes)
}
}

View File

@ -0,0 +1,547 @@
use alloc::{
borrow::ToOwned,
format,
string::{String, ToString},
vec::Vec,
};
use core::any;
use core::fmt::{Display, Write};
/// Default display settings for a module.
pub trait ModuleDisplayDefault {
/// Attributes of the module used for display purposes.
///
/// # Arguments
///
/// * `_content` - The content object that contains display settings and attributes.
///
/// # Returns
///
/// An optional content object containing the display attributes.
fn content(&self, _content: Content) -> Option<Content>;
/// Gets the number of the parameters of the module.
fn num_params(&self) -> usize {
0
}
}
/// Trait to implement custom display settings for a module.
///
/// In order to implement custom display settings for a module,
/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)]
/// 2. Implement ModuleDisplay trait for the module
pub trait ModuleDisplay: ModuleDisplayDefault {
/// Formats the module with provided display settings.
///
/// # Arguments
///
/// * `passed_settings` - Display settings passed to the module.
///
/// # Returns
///
/// A string representation of the formatted module.
fn format(&self, passed_settings: DisplaySettings) -> String {
let settings = if let Some(custom_settings) = self.custom_settings() {
custom_settings.inherit(passed_settings)
} else {
passed_settings
};
let indent = " ".repeat(settings.level * settings.indentation_size());
let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size());
let settings = settings.level_up();
let self_type = extract_type_name::<Self>();
// Use custom content if it is implemented and show_all_attributes is false,
// otherwise use default content
let content = if !settings.show_all_attributes() {
self.custom_content(Content::new(settings.clone()))
.unwrap_or_else(|| {
self.content(Content::new(settings.clone()))
.unwrap_or_else(|| {
panic!("Default content should be implemented for {self_type}.")
})
})
} else {
self.content(Content::new(settings.clone()))
.unwrap_or_else(|| panic!("Default content should be implemented for {self_type}."))
};
let top_level_type = if let Some(top_level_type) = content.top_level_type {
top_level_type.to_owned()
} else {
self_type.to_owned()
};
// If there is only one item in the content, return it or no attributes
if let Some(item) = content.single_item {
return item;
} else if content.attributes.is_empty() {
return top_level_type.to_string();
}
let mut result = String::new();
// Print the struct name
if settings.new_line_after_attribute() {
writeln!(result, "{} {{", top_level_type).unwrap();
} else {
write!(result, "{} {{", top_level_type).unwrap();
}
for (i, attribute) in content.attributes.iter().enumerate() {
if settings.new_line_after_attribute() {
writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap();
} else if i == 0 {
write!(result, "{}: {}", attribute.name, attribute.value).unwrap();
} else {
write!(result, ", {}: {}", attribute.name, attribute.value).unwrap();
}
}
if settings.show_num_parameters() {
let num_params = self.num_params();
if num_params > 0 {
if settings.new_line_after_attribute() {
writeln!(result, "{indent}params: {}", num_params).unwrap();
} else {
write!(result, ", params: {}", num_params).unwrap();
}
}
}
if settings.new_line_after_attribute() {
write!(result, "{indent_close_braces}}}").unwrap();
} else {
write!(result, "}}").unwrap();
}
result
}
/// Custom display settings for the module.
///
/// # Returns
///
/// An optional display settings object.
fn custom_settings(&self) -> Option<DisplaySettings> {
None
}
/// Custom attributes for the module.
///
/// # Arguments
///
/// * `_content` - The content object that contains display settings and attributes.
///
/// # Returns
///
/// An optional content object containing the custom attributes.
fn custom_content(&self, _content: Content) -> Option<Content> {
None
}
}
/// Custom module display settings.
#[derive(Debug, Clone)]
pub struct DisplaySettings {
/// Whether to print the module parameter ids.
show_param_id: Option<bool>,
/// Whether to print the module attributes.
show_all_attributes: Option<bool>,
/// Whether to print the module number of parameters.
show_num_parameters: Option<bool>,
/// Print new line after an attribute.
new_line_after_attribute: Option<bool>,
/// Indentation size.
indentation_size: Option<usize>,
/// Level of indentation.
level: usize,
}
impl Default for DisplaySettings {
fn default() -> Self {
DisplaySettings {
show_param_id: None,
show_all_attributes: None,
show_num_parameters: None,
new_line_after_attribute: None,
indentation_size: None,
level: 1,
}
}
}
impl DisplaySettings {
/// Create a new format settings.
///
/// # Returns
///
/// A new instance of `DisplaySettings`.
pub fn new() -> Self {
Default::default()
}
/// Sets a flag to show module parameters.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show module parameters.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_param_id(mut self, flag: bool) -> Self {
self.show_param_id = Some(flag);
self
}
/// Sets a flag to show module attributes.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show all module attributes.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_all_attributes(mut self, flag: bool) -> Self {
self.show_all_attributes = Some(flag);
self
}
/// Sets a flag to show the number of module parameters.
///
/// # Arguments
///
/// * `flag` - Boolean flag to show the number of module parameters.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_show_num_parameters(mut self, flag: bool) -> Self {
self.show_num_parameters = Some(flag);
self
}
/// Sets a flag to print a new line after an attribute.
///
/// # Arguments
///
/// * `flag` - Boolean flag to print a new line after an attribute.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self {
self.new_line_after_attribute = Some(flag);
self
}
/// Sets the indentation size.
///
/// # Arguments
///
/// * `size` - The size of the indentation.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn with_indentation_size(mut self, size: usize) -> Self {
self.indentation_size = Some(size);
self
}
/// Inherits settings from the provided settings and return a new settings object.
///
/// # Arguments
///
/// * `top` - The top level `DisplaySettings` to inherit from.
///
/// # Returns
///
/// Updated `DisplaySettings` instance.
pub fn inherit(self, top: Self) -> Self {
let mut updated = self.clone();
if let Some(show_param_id) = top.show_param_id {
updated.show_param_id = Some(show_param_id);
};
if let Some(show_all_attributes) = top.show_all_attributes {
updated.show_all_attributes = Some(show_all_attributes);
}
if let Some(show_num_parameters) = top.show_num_parameters {
updated.show_num_parameters = Some(show_num_parameters);
}
if let Some(new_line_after_attribute) = top.new_line_after_attribute {
updated.new_line_after_attribute = Some(new_line_after_attribute);
}
if let Some(indentation_size) = top.indentation_size {
updated.indentation_size = Some(indentation_size);
}
updated.level = top.level;
updated
}
/// A convenience method to wrap the DisplaySettings struct in an option.
///
/// # Returns
///
/// An optional `DisplaySettings`.
pub fn optional(self) -> Option<Self> {
Some(self)
}
/// Increases the level of indentation.
///
/// # Returns
///
/// Updated `DisplaySettings` instance with increased indentation level.
pub fn level_up(mut self) -> Self {
self.level += 1;
self
}
/// Gets `show_param_id` flag, substitutes false if not set.
///
/// This flag is used to print the module parameter ids.
///
/// # Returns
///
/// A boolean value indicating whether to show parameter ids.
pub fn show_param_id(&self) -> bool {
self.show_param_id.unwrap_or(false)
}
/// Gets `show_all_attributes`, substitutes false if not set.
///
/// This flag is used to force to print all module attributes, overriding custom attributes.
///
/// # Returns
///
/// A boolean value indicating whether to show all attributes.
pub fn show_all_attributes(&self) -> bool {
self.show_all_attributes.unwrap_or(false)
}
/// Gets `show_num_parameters`, substitutes true if not set.
///
/// This flag is used to print the number of module parameters.
///
/// # Returns
///
/// A boolean value indicating whether to show the number of parameters.
pub fn show_num_parameters(&self) -> bool {
self.show_num_parameters.unwrap_or(true)
}
/// Gets `new_line_after_attribute`, substitutes true if not set.
///
/// This flag is used to print a new line after an attribute.
///
/// # Returns
///
/// A boolean value indicating whether to print a new line after an attribute.
pub fn new_line_after_attribute(&self) -> bool {
self.new_line_after_attribute.unwrap_or(true)
}
/// Gets `indentation_size`, substitutes 2 if not set.
///
/// This flag is used to set the size of indentation.
///
/// # Returns
///
/// An integer value indicating the size of indentation.
pub fn indentation_size(&self) -> usize {
self.indentation_size.unwrap_or(2)
}
}
/// Struct to store the attributes of a module for formatting.
#[derive(Clone, Debug)]
pub struct Content {
/// List of attributes.
pub attributes: Vec<Attribute>,
/// Single item content.
pub single_item: Option<String>,
/// Display settings.
pub display_settings: DisplaySettings,
/// Top level type name.
pub top_level_type: Option<String>,
}
impl Content {
/// Creates a new attributes struct.
///
/// # Arguments
///
/// * `display_settings` - Display settings for the content.
///
/// # Returns
///
/// A new instance of `Content`.
pub fn new(display_settings: DisplaySettings) -> Self {
Content {
attributes: Vec::new(),
single_item: None,
display_settings,
top_level_type: None,
}
}
/// Adds an attribute to the format settings. The value will be formatted and stored as a string.
///
/// # Arguments
///
/// * `name` - Name of the attribute.
/// * `value` - Value of the attribute.
///
/// # Returns
///
/// Updated `Content` instance with the new attribute added.
pub fn add<T: ModuleDisplay + ?Sized>(mut self, name: &str, value: &T) -> Self {
if self.single_item.is_some() {
panic!("Cannot add multiple attributes when single item is set.");
}
let attribute = Attribute {
name: name.to_owned(),
value: value.format(self.display_settings.clone()), // TODO level + 1
ty: any::type_name::<T>().to_string(),
};
self.attributes.push(attribute);
self
}
/// Adds a single item.
///
/// # Arguments
///
/// * `value` - Rendered string of the single item.
///
/// # Returns
///
/// Updated `Content` instance with the single item added.
pub fn add_single<T: ModuleDisplay + ?Sized>(mut self, value: &T) -> Self {
if !self.attributes.is_empty() {
panic!("Cannot add single item when attributes are set.");
}
self.single_item = Some(value.format(self.display_settings.clone()));
self
}
/// Adds a single item.
///
/// # Arguments
///
/// * `value` - Formatted display value.
///
/// # Returns
///
/// Updated `Content` instance with the formatted single item added.
pub fn add_formatted<T: Display>(mut self, value: &T) -> Self {
if !self.attributes.is_empty() {
panic!("Cannot add single item when attributes are set.");
}
self.single_item = Some(format!("{}", value));
self
}
/// A convenience method to wrap the Attributes struct in an option
/// because it is often used as an optional field.
///
/// # Returns
///
/// An optional `Content`.
pub fn optional(self) -> Option<Self> {
if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none()
{
None
} else {
Some(self)
}
}
/// Sets the top level type name.
///
/// # Arguments
///
/// * `ty` - The type name to set.
///
/// # Returns
///
/// Updated `Content` instance with the top level type name set.
pub fn set_top_level_type(mut self, ty: &str) -> Self {
self.top_level_type = Some(ty.to_owned());
self
}
}
/// Attribute to print in the display method.
#[derive(Clone, Debug)]
pub struct Attribute {
/// Name of the attribute.
pub name: String,
/// Value of the attribute.
pub value: String,
/// Type of the attribute.
pub ty: String,
}
/// Extracts the short name of a type T
///
/// # Returns
///
/// A string slice representing the short name of the type.
pub fn extract_type_name<T: ?Sized>() -> &'static str {
// Get the full type name of T, including module path and generic parameters
let ty = any::type_name::<T>();
// Find the first occurrence of '<' in the full type name
// If not found, use the length of the type name
let end = ty.find('<').unwrap_or(ty.len());
// Slice the type name up to the first '<' or the end
let ty = &ty[0..end];
// Find the last occurrence of "::" in the sliced type name
// If found, add 2 to skip the "::" itself
// If not found, start from the beginning of the type name
let start = ty.rfind("::").map(|i| i + 2).unwrap_or(0);
// Find the last occurrence of '<' in the sliced type name
// If not found, use the length of the type name
let end = ty.rfind('<').unwrap_or(ty.len());
// If the start index is less than the end index,
// return the slice of the type name from start to end
// Otherwise, return the entire sliced type name
if start < end {
&ty[start..end]
} else {
ty
}
}

View File

@ -1,5 +1,7 @@
mod base;
mod display;
mod param;
pub use base::*;
pub use display::*;
pub use param::*;

View File

@ -1,6 +1,12 @@
use alloc::{format, string::ToString};
use core::{fmt::Display, marker::PhantomData};
use crate::{
self as burn,
module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
module::{
AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
ModuleMapper, ModuleVisitor,
},
record::Record,
};
use burn::record::PrecisionSettings;
@ -8,7 +14,6 @@ use burn_tensor::{
backend::{AutodiffBackend, Backend},
BasicAutodiffOps, BasicOps, Tensor,
};
use core::marker::PhantomData;
/// Record used for constant type implementing the [module](crate::module::Module) trait.
#[derive(Debug, Clone, Copy, new, Default)]
@ -96,6 +101,15 @@ macro_rules! constant {
impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {
constant!(ad_module, $type);
}
impl burn::module::ModuleDisplayDefault for $type {
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
let string = format!("{}", self);
content.add_formatted(&string).optional()
}
}
impl burn::module::ModuleDisplay for $type {}
};
}
@ -122,6 +136,13 @@ constant!(i32);
constant!(i16);
constant!(i8);
impl burn::module::ModuleDisplay for str {}
impl burn::module::ModuleDisplayDefault for str {
fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
content.add_formatted(&self).optional()
}
}
impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
type Record = ConstantRecord;
@ -158,6 +179,15 @@ impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
}
}
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
fn content(&self, content: Content) -> Option<Content> {
let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
content.add_single(&string).optional()
}
}
impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
for Tensor<B, D, K>
{
@ -200,6 +230,14 @@ impl<B: Backend> Module<B> for PhantomData<B> {
}
}
impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
fn content(&self, content: Content) -> Option<Content> {
content.add_single(&"PhantomData".to_string()).optional()
}
}
impl<B: Backend> ModuleDisplay for PhantomData<B> {}
impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
type InnerModule = PhantomData<B::InnerBackend>;
@ -248,6 +286,27 @@ where
}
}
impl<T> ModuleDisplayDefault for Ignored<T>
where
T: Sync + Send + core::fmt::Debug + Clone,
{
fn content(&self, content: Content) -> Option<Content> {
// For now, just print the debug representation of the ignored value
content.add_single(&format!("{:?}", self.0)).optional()
}
}
impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
impl<T> Display for Ignored<T>
where
T: Sync + Send + core::fmt::Debug + Clone,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
where
B: AutodiffBackend,

View File

@ -1,5 +1,10 @@
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
use alloc::vec::Vec;
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor,
};
use alloc::{format, vec::Vec};
use burn_tensor::backend::{AutodiffBackend, Backend};
use core::fmt::Debug;
@ -52,6 +57,17 @@ where
}
}
impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
fn content(&self, content: Content) -> Option<Content> {
match self {
Some(module) => content.add_single(module).optional(),
None => content.add_single("None").optional(),
}
}
}
impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
impl<T, B> AutodiffModule<B> for Option<T>
where
T: AutodiffModule<B> + Debug + Send + Clone,
@ -128,6 +144,21 @@ where
}
}
impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
fn content(&self, content: Content) -> Option<Content> {
self.iter()
.enumerate()
.fold(content, |acc, (i, module)| {
let index = format!("{}", i);
acc.add(&index, module)
})
.set_top_level_type(format!("Vec<0..{}>", self.len()).as_str())
.optional()
}
}
impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
impl<T, B> AutodiffModule<B> for Vec<T>
where
T: AutodiffModule<B> + Debug + Send + Clone,
@ -197,6 +228,21 @@ where
}
}
impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
fn content(&self, content: Content) -> Option<Content> {
self.iter()
.enumerate()
.fold(content, |acc, (i, module)| {
let index = format!("{}", i);
acc.add(&index, module)
})
.set_top_level_type(format!("[0..{}]", self.len()).as_str())
.optional()
}
}
impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
where
T: AutodiffModule<B> + Debug + Send + Clone + Copy,
@ -269,6 +315,21 @@ macro_rules! impl_module_tuple {
($(self.$i.valid(),)*)
}
}
impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)
where
$($l: ModuleDisplay,)*
{
fn content(&self, content: Content) -> Option<Content> {
let content = content
$(.add(&format!("{}", $i), &self.$i))*
.set_top_level_type(format!("({})", stringify!($($l),*)).as_str());
content.optional()
}
}
impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}
};
}

View File

@ -1,7 +1,13 @@
use super::ParamId;
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor, Param,
};
use alloc::string::ToString;
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::stub::Mutex;
use burn_tensor::{
backend::{AutodiffBackend, Backend},
@ -45,6 +51,24 @@ pub struct RunningState<V> {
value: Arc<Mutex<V>>,
}
// Implement display for the module
impl<V> core::fmt::Display for RunningState<V> {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "RunningState(id={})", self.id)
}
}
impl<V> ModuleDisplayDefault for RunningState<V> {
fn content(&self, content: Content) -> Option<Content> {
content
.add_formatted(&"RunningState".to_string())
.optional()
}
}
impl<V> ModuleDisplay for RunningState<V> {}
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
type Record = Param<Tensor<B, D>>;

View File

@ -1,10 +1,13 @@
use super::{Param, ParamId, Parameter};
use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
use crate::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor,
};
use crate::tensor::{
backend::{AutodiffBackend, Backend},
Tensor,
};
use alloc::vec::Vec;
use alloc::{format, string::ToString, vec::Vec};
use burn_tensor::{Bool, Data, Float, Int};
impl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {
@ -147,6 +150,22 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
}
}
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
fn content(&self, content: Content) -> Option<Content> {
let id = if content.display_settings.show_param_id() {
format!(", id: {}", self.id)
} else {
"".to_string()
};
let string = format!(
"ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}",
self.shape().dims
);
content.add_formatted(&string).optional()
}
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
type Record = Param<Tensor<B, D, Int>>;
@ -198,6 +217,22 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
}
}
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
fn content(&self, content: Content) -> Option<Content> {
let id = if content.display_settings.show_param_id() {
format!(", id: {}", self.id)
} else {
"".to_string()
};
let string = format!(
"ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}",
self.shape().dims
);
content.add_formatted(&string).optional()
}
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
type Record = Param<Tensor<B, D, Bool>>;
@ -249,6 +284,24 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
}
}
impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
fn content(&self, content: Content) -> Option<Content> {
let id = if content.display_settings.show_param_id() {
format!(", id: {}", self.id)
} else {
"".to_string()
};
let string = format!(
"ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}",
self.shape().dims
);
content.add_formatted(&string).optional()
}
}
impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
type InnerModule = Param<Tensor<B::InnerBackend, D>>;

View File

@ -1,14 +1,13 @@
use alloc::format;
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::nn::conv::checks;
use crate::nn::{Initializer, PaddingConfig1d};
use crate::tensor::backend::Backend;
use crate::tensor::module::conv1d;
use crate::tensor::ops::ConvOptions;
use crate::tensor::Tensor;
use crate::{
config::Config,
module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param},
nn::{conv::checks, Initializer, PaddingConfig1d},
tensor::{backend::Backend, module::conv1d, ops::ConvOptions, Tensor},
};
/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
#[derive(Config, Debug)]
@ -45,6 +44,7 @@ pub struct Conv1dConfig {
///
/// Should be created with [Conv1dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv1d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
pub weight: Param<Tensor<B, 3>>,
@ -54,7 +54,28 @@ pub struct Conv1d<B: Backend> {
kernel_size: usize,
dilation: usize,
groups: usize,
padding: PaddingConfig1d,
padding: Ignored<PaddingConfig1d>,
}
impl<B: Backend> ModuleDisplay for Conv1d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Since padding does not implement ModuleDisplay, we need to format it manually.
let padding_formatted = format!("{}", &self.padding);
content
.add("stride", &self.stride)
.add("kernel_size", &self.kernel_size)
.add("dilation", &self.dilation)
.add("groups", &self.groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl Conv1dConfig {
@ -87,7 +108,7 @@ impl Conv1dConfig {
bias,
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
groups: self.groups,
}

View File

@ -1,8 +1,9 @@
use alloc::format;
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
use crate::nn::Initializer;
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
@ -45,6 +46,7 @@ pub struct Conv2dConfig {
///
/// Should be created with [Conv2dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv2d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
pub weight: Param<Tensor<B, 4>>,
@ -54,7 +56,7 @@ pub struct Conv2d<B: Backend> {
kernel_size: [usize; 2],
dilation: [usize; 2],
groups: usize,
padding: PaddingConfig2d,
padding: Ignored<PaddingConfig2d>,
}
impl Conv2dConfig {
@ -93,12 +95,38 @@ impl Conv2dConfig {
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
groups: self.groups,
}
}
}
impl<B: Backend> ModuleDisplay for Conv2d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Since padding does not implement ModuleDisplay, we need to format it manually.
let padding_formatted = format!("{}", &self.padding);
// Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
let stride = format!("{:?}", self.stride);
let kernel_size = format!("{:?}", self.kernel_size);
let dilation = format!("{:?}", self.dilation);
content
.add("stride", &stride)
.add("kernel_size", &kernel_size)
.add("dilation", &dilation)
.add("groups", &self.groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl<B: Backend> Conv2d<B> {
/// Applies the forward pass on the input tensor.
///

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{DisplaySettings, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};
@ -21,6 +21,7 @@ pub struct DropoutConfig {
///
/// Should be created with [DropoutConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Dropout {
prob: f64,
}
@ -54,6 +55,18 @@ impl Dropout {
}
}
impl ModuleDisplay for Dropout {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
content.add("prob", &self.prob).optional()
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,4 +1,6 @@
use crate as burn;
use crate::module::DisplaySettings;
use crate::module::ModuleDisplay;
use crate::config::Config;
use crate::module::Module;
@ -30,6 +32,7 @@ pub struct LinearConfig {
///
/// `O = IW + b`
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Linear<B: Backend> {
/// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
@ -83,6 +86,23 @@ impl<B: Backend> Linear<B> {
}
}
impl<B: Backend> ModuleDisplay for Linear<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
let [d_input, d_output] = self.weight.shape().dims;
content
.add("d_input", &d_input)
.add("d_output", &d_output)
.add("bias", &self.bias.is_some())
.optional()
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,4 +1,5 @@
use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::Initializer;
use crate::{
@ -33,6 +34,7 @@ pub struct BatchNormConfig {
///
/// Should be created using [BatchNormConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BatchNorm<B: Backend, const D: usize> {
/// The learnable weight gamma.
pub gamma: Param<Tensor<B, 1>>,
@ -183,6 +185,24 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
}
}
impl<const D: usize, B: Backend> ModuleDisplay for BatchNorm<B, D> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [num_features] = self.beta.shape().dims;
content
.add("num_features", &num_features)
.add("momentum", &self.momentum)
.add("epsilon", &self.epsilon)
.optional()
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {

View File

@ -1,7 +1,8 @@
use crate as burn;
use crate::config::Config;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
@ -29,6 +30,7 @@ pub struct LayerNormConfig {
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
/// The learnable weight.
gamma: Param<Tensor<B, 1>>,
@ -71,6 +73,22 @@ impl<B: Backend> LayerNorm<B> {
}
}
impl<B: Backend> ModuleDisplay for LayerNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
let [d_model] = self.gamma.shape().dims;
content
.add("d_model", &d_model)
.add("epsilon", &self.epsilon)
.optional()
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -3,10 +3,9 @@ use crate as burn;
use crate::tensor::ops::conv::calculate_conv_padding;
use crate::config::Config;
use crate::module::Module;
/// Padding configuration for 1D operators.
#[derive(Module, Config, Debug, PartialEq)]
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig1d {
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
/// the same as the input.
@ -34,7 +33,7 @@ impl PaddingConfig1d {
}
/// Padding configuration for 2D operators.
#[derive(Module, Config, Debug, PartialEq)]
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig2d {
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
/// the same as the input.

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -43,7 +43,7 @@ pub struct AvgPool1dConfig {
pub struct AvgPool1d {
stride: usize,
kernel_size: usize,
padding: PaddingConfig1d,
padding: Ignored<PaddingConfig1d>,
count_include_pad: bool,
}
@ -53,7 +53,7 @@ impl AvgPool1dConfig {
AvgPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
count_include_pad: self.count_include_pad,
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -42,7 +42,7 @@ pub struct AvgPool2dConfig {
pub struct AvgPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: PaddingConfig2d,
padding: Ignored<PaddingConfig2d>,
count_include_pad: bool,
}
@ -52,7 +52,7 @@ impl AvgPool2dConfig {
AvgPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
count_include_pad: self.count_include_pad,
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -31,7 +31,7 @@ pub struct MaxPool1dConfig {
pub struct MaxPool1d {
stride: usize,
kernel_size: usize,
padding: PaddingConfig1d,
padding: Ignored<PaddingConfig1d>,
dilation: usize,
}
@ -41,7 +41,7 @@ impl MaxPool1dConfig {
MaxPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -31,7 +31,7 @@ pub struct MaxPool2dConfig {
pub struct MaxPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: PaddingConfig2d,
padding: Ignored<PaddingConfig2d>,
dilation: [usize; 2],
}
@ -41,7 +41,7 @@ impl MaxPool2dConfig {
MaxPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
padding: Ignored(self.padding.clone()),
dilation: self.dilation,
}
}

View File

@ -1,12 +1,11 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::ops::UnfoldOptions;
use crate::tensor::Tensor;
use crate::tensor::module::unfold4d;
use crate::module::{Ignored, Module};
use burn_tensor::backend::Backend;
use burn_tensor::module::unfold4d;
use burn_tensor::ops::UnfoldOptions;
use burn_tensor::Tensor;
/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).
#[derive(Config, Debug)]
@ -29,14 +28,14 @@ pub struct Unfold4dConfig {
/// Should be created with [Unfold4dConfig].
#[derive(Module, Clone, Debug)]
pub struct Unfold4d {
config: Unfold4dConfig,
config: Ignored<Unfold4dConfig>,
}
impl Unfold4dConfig {
/// Initializes a new [Unfold4d] module.
pub fn init(&self) -> Unfold4d {
Unfold4d {
config: self.clone(),
config: Ignored(self.clone()),
}
}
}
@ -48,7 +47,7 @@ impl Unfold4d {
///
/// # Shapes
///
/// input: `[batch_size, channels_in, height, width]`
/// input: `[batch_size, channels_in, height, width]`
/// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
unfold4d(

View File

@ -370,6 +370,6 @@ mod tests {
// Compare the lengths of expected and actual serialized strings because
// the order of the fields is not guaranteed for HashMaps.
assert_eq!(serialized_str.len(), 134);
assert_eq!(serialized_str.len(), 108);
}
}

View File

@ -13,7 +13,7 @@ pub(crate) mod record;
pub(crate) mod shared;
/// Derive macro for the module.
#[proc_macro_derive(Module)]
#[proc_macro_derive(Module, attributes(module))]
pub fn module_derive(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
module::derive_impl(&input)

View File

@ -2,7 +2,7 @@ use super::{display, record::ModuleRecordCodegen};
use crate::shared::generics::GenericsHelper;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_quote, Generics};
use syn::{parse_quote, Attribute, Generics};
/// Basic trait to be implemented for Module generation.
pub(crate) trait ModuleCodegen {
@ -30,8 +30,8 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
let generics = GenericsParser::from_ast(&ast.generics);
let display_fn = display::display_fn(name);
let display_fn = display::display_fn(ast);
let attributes_fn = display::attributes_fn(ast);
let num_params_fn = codegen.gen_num_params();
let visit = codegen.gen_visit();
let map_mut = codegen.gen_map();
@ -54,7 +54,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
let generics_ty_inner_module = generics.inner_module_ty;
let gen = quote! {
let mut gen = quote! {
impl #generics_module burn::module::Module<B> for #name #generics_ty_module #generics_where_module {
type Record = #record_name #generics_ty_module;
@ -69,6 +69,7 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
#collect_devices
#to_device
#fork
}
impl #generics_module_autodiff burn::module::AutodiffModule<B> for #name #generics_ty_module_autodiff #generics_where_module_autodiff
@ -82,6 +83,15 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
#display_fn
}
impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module {
#attributes_fn
fn num_params(&self) -> usize {
burn::module::Module::num_params(self)
}
}
impl #generics_module Clone for #name #generics_ty_module #generics_where_module {
#clone_fn
}
@ -89,13 +99,21 @@ pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
#record_type
};
if !has_custom_display(&ast.attrs) {
gen.extend(quote! {
impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module {
}
});
}
gen
}
// When there is no backend in the generic parameter, the type is considered as a constant.
pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let (_generics, generics_ty, generics_where) = ast.generics.split_for_impl();
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::AutodiffBackend >};
@ -112,7 +130,10 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
let (generics_module, _, _) = generics_module.split_for_impl();
let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl();
let gen = quote! {
let display_fn = display::display_fn(ast);
let attributes_fn = display::attributes_fn(ast);
let mut gen = quote! {
impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
burn::constant!(module);
}
@ -121,8 +142,26 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
for #name #generics_ty #generics_where {
burn::constant!(ad_module, #name #generics_ty);
}
impl #generics core::fmt::Display for #name #generics_ty #generics_where {
#display_fn
}
impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where {
#attributes_fn
}
};
if !has_custom_display(&ast.attrs) {
gen.extend(quote! {
impl #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where {
}
});
}
gen
}
@ -159,22 +198,64 @@ impl GenericsParser {
#ident: burn::module::Module<B>
}
);
module.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplayDefault
}
);
module.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplay
}
);
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::AutodiffModule<B>
}
);
module_autodiff.add_predicate(
module_autodiff.add_predicate(
parse_quote! {
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::Module<B::InnerBackend>
}
);
module_autodiff.add_predicate(
parse_quote! {
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
}
);
module_autodiff.add_predicate(
parse_quote! {
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
}
);
generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule<B>>::InnerModule, });
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::Module<B::InnerBackend>
}
);
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplayDefault
}
);
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplay
}
);
});
module.consts().into_iter().for_each(|ident| {
@ -188,3 +269,18 @@ impl GenericsParser {
}
}
}
fn has_custom_display(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| {
attr.path().is_ident("module")
&& attr
.parse_nested_meta(|meta| {
if meta.path.is_ident("custom_display") {
Ok(())
} else {
Err(meta.error("unsupported attribute"))
}
})
.is_ok()
})
}

View File

@ -1,11 +1,96 @@
use proc_macro2::Ident;
use quote::quote;
pub fn display_fn(name: &Ident) -> proc_macro2::TokenStream {
pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
match &ast.data {
syn::Data::Struct(ref data_struct) => {
let fields = match &data_struct.fields {
syn::Fields::Named(ref named_fields) => {
named_fields.named.iter().collect::<Vec<_>>()
}
syn::Fields::Unit => Vec::new(),
_ => panic!("attributes_fn only supports structs with named or unit fields"),
};
let field_prints = fields.iter().map(|field| {
let field_name = &field.ident;
quote! { .add(stringify!(#field_name), &self.#field_name) }
});
let struct_name = &ast.ident;
quote! {
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
content
.set_top_level_type(&stringify!(#struct_name))
#(#field_prints)*
.optional()
}
}
}
syn::Data::Enum(ref data_enum) => {
let variant_prints = data_enum.variants.iter().map(|variant| {
let variant_name = &variant.ident;
match &variant.fields {
syn::Fields::Unit => {
quote! {
Self::#variant_name => {
content.add_formatted(&stringify!(#variant_name).to_string())
.optional()
}
}
}
syn::Fields::Named(ref named_fields) => {
let field_prints = named_fields.named.iter().map(|field| {
let field_name = &field.ident;
quote! { .add(stringify!(#field_name), &self.#field_name) }
});
let field_names = named_fields.named.iter().map(|field| {
let field_name = &field.ident;
quote! { #field_name }
});
quote! {
Self::#variant_name { #(#field_names),* } => {
content.set_top_level_type(&stringify!(#variant_name))
#(#field_prints)*
.optional()
}
}
}
syn::Fields::Unnamed(ref unnamed_fields) => {
let field_names = (0..unnamed_fields.unnamed.len()).map(|i| {
syn::Ident::new(&format!("_{}", i), proc_macro2::Span::call_site())
});
let field_prints = field_names.clone().map(|field_name| {
quote! { .add(stringify!(#field_name), #field_name) }
});
quote! {
Self::#variant_name(#(#field_names),*) => {
content.set_top_level_type(&stringify!(#variant_name))
#(#field_prints)*
.optional()
}
}
}
}
});
quote! {
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
match self {
#(#variant_prints)*
}
}
}
}
_ => panic!("attributes_fn only supports structs and enums"),
}
}
pub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
quote! {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{}[num_params={}]", stringify!(#name), self.num_params())
let formatted = burn::module::ModuleDisplay::format(self, Default::default());
write!(f, "{}", formatted)
}
}
}

View File

@ -151,7 +151,7 @@ impl Display for LearnerSummary {
)?;
if let Some(model) = &self.model {
writeln!(f, "Model: {model}")?;
writeln!(f, "Model:\n{model}")?;
}
writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;