From 98a58c867d420d724cb81e8aa15c7a5843aaa7a1 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Fri, 28 Jun 2024 07:37:40 -0500 Subject: [PATCH] Print module - implement module display for remaining modules (part2) (#1933) --- crates/burn-core/src/nn/attention/mha.rs | 67 +++++++++++++--- crates/burn-core/src/nn/conv/conv1d.rs | 26 +++++-- crates/burn-core/src/nn/conv/conv2d.rs | 26 +++++-- .../burn-core/src/nn/conv/conv_transpose1d.rs | 58 ++++++++++++-- .../burn-core/src/nn/conv/conv_transpose2d.rs | 58 ++++++++++++-- crates/burn-core/src/nn/dropout.rs | 15 +++- crates/burn-core/src/nn/embedding.rs | 29 +++++++ crates/burn-core/src/nn/gelu.rs | 14 +++- crates/burn-core/src/nn/leaky_relu.rs | 25 ++++++ crates/burn-core/src/nn/linear.rs | 17 ++++- .../src/nn/loss/binary_cross_entropy.rs | 36 ++++++++- crates/burn-core/src/nn/loss/cross_entropy.rs | 50 +++++++++++- crates/burn-core/src/nn/loss/huber.rs | 61 ++++++++++----- crates/burn-core/src/nn/loss/mse.rs | 30 ++++---- crates/burn-core/src/nn/norm/batch.rs | 17 ++++- crates/burn-core/src/nn/norm/group.rs | 42 +++++++++- crates/burn-core/src/nn/norm/instance.rs | 38 +++++++++- crates/burn-core/src/nn/norm/layer.rs | 20 ++++- crates/burn-core/src/nn/norm/rms.rs | 32 +++++++- .../src/nn/pool/adaptive_avg_pool1d.rs | 33 +++++++- .../src/nn/pool/adaptive_avg_pool2d.rs | 35 ++++++++- crates/burn-core/src/nn/pool/avg_pool1d.rs | 47 +++++++++++- crates/burn-core/src/nn/pool/avg_pool2d.rs | 48 +++++++++++- crates/burn-core/src/nn/pool/max_pool1d.rs | 48 +++++++++++- crates/burn-core/src/nn/pool/max_pool2d.rs | 48 +++++++++++- crates/burn-core/src/nn/pos_encoding.rs | 45 ++++++++++- crates/burn-core/src/nn/prelu.rs | 42 +++++++++- crates/burn-core/src/nn/relu.rs | 14 +++- crates/burn-core/src/nn/rnn/gru.rs | 45 ++++++++++- crates/burn-core/src/nn/rnn/lstm.rs | 71 ++++++++++++++++- crates/burn-core/src/nn/rope_encoding.rs | 43 ++++++++++- crates/burn-core/src/nn/swiglu.rs | 33 +++++++- crates/burn-core/src/nn/tanh.rs | 14 +++- .../burn-core/src/nn/transformer/decoder.rs | 76 +++++++++++++++++-- .../burn-core/src/nn/transformer/encoder.rs | 73 ++++++++++++++++-- crates/burn-core/src/nn/transformer/pwff.rs | 50 ++++++++++-- crates/burn-core/src/nn/unfold.rs | 59 +++++++++++--- .../pytorch-tests/tests/linear/mod.rs | 2 +- .../src/burn/node/conv_transpose_2d.rs | 1 + crates/burn-import/src/burn/node/prelu.rs | 3 +- crates/burn-train/src/learner/train_val.rs | 2 +- 41 files changed, 1340 insertions(+), 153 deletions(-) diff --git a/crates/burn-core/src/nn/attention/mha.rs b/crates/burn-core/src/nn/attention/mha.rs index ed6eb4923..ad754bdf4 100644 --- a/crates/burn-core/src/nn/attention/mha.rs +++ b/crates/burn-core/src/nn/attention/mha.rs @@ -1,10 +1,10 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::nn::cache::TensorCache; use crate::nn::Initializer; use crate::{ config::Config, - module::Module, nn, tensor::{activation, backend::Backend, Bool, Tensor}, }; @@ -53,17 +53,49 @@ pub struct MultiHeadAttentionConfig { /// /// Should be created with [MultiHeadAttentionConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct MultiHeadAttention { - query: nn::Linear, - key: nn::Linear, - value: nn::Linear, - output: nn::Linear, - dropout: nn::Dropout, - activation: nn::Gelu, - n_heads: usize, - d_k: usize, - min_float: f64, - quiet_softmax: bool, + /// Linear layer to transform the input features into the query space. 
+ pub query: nn::Linear, + /// Linear layer to transform the input features into the key space. + pub key: nn::Linear, + /// Linear layer to transform the input features into the value space. + pub value: nn::Linear, + /// Linear layer to transform the output features back to the original space. + pub output: nn::Linear, + /// Dropout layer. + pub dropout: nn::Dropout, + /// Activation function. + pub activation: nn::Gelu, + /// The size of each linear layer. + pub d_model: usize, + /// The number of heads. + pub n_heads: usize, + /// Size of the key and query vectors. + pub d_k: usize, + /// Minimum value a float can take. + pub min_float: f64, + /// Use "quiet softmax" instead of regular softmax. + pub quiet_softmax: bool, +} + +impl ModuleDisplay for MultiHeadAttention { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("d_model", &self.d_model) + .add("n_heads", &self.n_heads) + .add("d_k", &self.d_k) + .add("dropout", &self.dropout.prob) + .add("min_float", &self.min_float) + .add("quiet_softmax", &self.quiet_softmax) + .optional() + } } /// [Multihead attention](MultiHeadAttention) forward pass input argument. @@ -99,6 +131,7 @@ impl MultiHeadAttentionConfig { d_k: self.d_model / self.n_heads, min_float: self.min_float, quiet_softmax: self.quiet_softmax, + d_model: self.d_model, } } } @@ -478,4 +511,16 @@ mod tests { .into_data() .assert_approx_eq(&output_2.into_data(), 3); } + + #[test] + fn display() { + let config = MultiHeadAttentionConfig::new(2, 4); + let mha = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", mha), + "MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \ + dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}" + ); + } } diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs index 6d6b98100..0b64eab32 100644 --- a/crates/burn-core/src/nn/conv/conv1d.rs +++ b/crates/burn-core/src/nn/conv/conv1d.rs @@ -50,11 +50,16 @@ pub struct Conv1d { pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, - stride: usize, - kernel_size: usize, - dilation: usize, - groups: usize, - padding: Ignored, + /// Stride of the convolution. + pub stride: usize, + /// Size of the kernel. + pub kernel_size: usize, + /// Spacing between kernel elements. + pub dilation: usize, + /// Controls the connections between input and output channels. + pub groups: usize, + /// Padding configuration. + pub padding: Ignored, } impl ModuleDisplay for Conv1d { @@ -169,4 +174,15 @@ mod tests { .to_data() .assert_approx_eq(&TensorData::zeros::(conv.weight.shape()), 3); } + + #[test] + fn display() { + let config = Conv1dConfig::new(5, 5, 5); + let conv = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", conv), + "Conv1d {stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}" + ); + } } diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index e5fc10adb..bf31fd966 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -52,11 +52,16 @@ pub struct Conv2d { pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, - stride: [usize; 2], - kernel_size: [usize; 2], - dilation: [usize; 2], - groups: usize, - padding: Ignored, + /// Stride of the convolution. + pub stride: [usize; 2], + /// Size of the kernel. 
+ pub kernel_size: [usize; 2], + /// Spacing between kernel elements. + pub dilation: [usize; 2], + /// Controls the connections between input and output channels. + pub groups: usize, + /// The padding configuration. + pub padding: Ignored, } impl Conv2dConfig { @@ -214,4 +219,15 @@ mod tests { assert_eq!(config.initializer, init); } + + #[test] + fn display() { + let config = Conv2dConfig::new([5, 1], [5, 5]); + let conv = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", conv), + "Conv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}" + ); + } } diff --git a/crates/burn-core/src/nn/conv/conv_transpose1d.rs b/crates/burn-core/src/nn/conv/conv_transpose1d.rs index c2d21309c..359847471 100644 --- a/crates/burn-core/src/nn/conv/conv_transpose1d.rs +++ b/crates/burn-core/src/nn/conv/conv_transpose1d.rs @@ -1,7 +1,12 @@ +use alloc::format; + use crate as burn; use crate::config::Config; +use crate::module::Content; +use crate::module::DisplaySettings; use crate::module::Module; +use crate::module::ModuleDisplay; use crate::module::Param; use crate::nn::conv::checks; use crate::nn::Initializer; @@ -45,17 +50,46 @@ pub struct ConvTranspose1dConfig { /// Applies a 1D transposed convolution over input tensors. #[derive(Module, Debug)] +#[module(custom_display)] pub struct ConvTranspose1d { /// Tensor of shape `[channels_in, channels_out / groups, kernel_size]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, - stride: usize, - kernel_size: usize, - dilation: usize, - groups: usize, - padding: usize, - padding_out: usize, + /// Stride of the convolution. + pub stride: usize, + /// Size of the kernel. + pub kernel_size: usize, + /// Spacing between kernel elements. + pub dilation: usize, + /// Controls the connections between input and output channels. + pub groups: usize, + /// The padding configuration. + pub padding: usize, + /// The padding output configuration. + pub padding_out: usize, + /// The number of channels. 
+ pub channels: [usize; 2], +} + +impl ModuleDisplay for ConvTranspose1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("channels", &format!("{:?}", &self.channels)) + .add("stride", &self.stride) + .add("kernel_size", &self.kernel_size) + .add("dilation", &self.dilation) + .add("groups", &self.groups) + .add("padding", &self.padding) + .add("padding_out", &self.padding_out) + .optional() + } } impl ConvTranspose1dConfig { @@ -91,6 +125,7 @@ impl ConvTranspose1dConfig { groups: self.groups, padding: self.padding, padding_out: self.padding_out, + channels: self.channels, } } } @@ -150,4 +185,15 @@ mod tests { .to_data() .assert_approx_eq(&TensorData::zeros::(conv.weight.shape()), 3); } + + #[test] + fn display() { + let config = ConvTranspose1dConfig::new([5, 2], 5); + let conv = config.init::(&Default::default()); + + assert_eq!( + format!("{}", conv), + "ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}" + ); + } } diff --git a/crates/burn-core/src/nn/conv/conv_transpose2d.rs b/crates/burn-core/src/nn/conv/conv_transpose2d.rs index f289b5abf..7fa3bd788 100644 --- a/crates/burn-core/src/nn/conv/conv_transpose2d.rs +++ b/crates/burn-core/src/nn/conv/conv_transpose2d.rs @@ -1,7 +1,12 @@ +use alloc::format; + use crate as burn; use crate::config::Config; +use crate::module::Content; +use crate::module::DisplaySettings; use crate::module::Module; +use crate::module::ModuleDisplay; use crate::module::Param; use crate::nn::conv::checks; use crate::nn::Initializer; @@ -45,17 +50,46 @@ pub struct ConvTranspose2dConfig { /// Applies a 2D transposed convolution over input tensors. #[derive(Module, Debug)] +#[module(custom_display)] pub struct ConvTranspose2d { /// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, - stride: [usize; 2], - kernel_size: [usize; 2], - dilation: [usize; 2], - groups: usize, - padding: [usize; 2], - padding_out: [usize; 2], + /// Stride of the convolution. + pub stride: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 2], + /// Spacing between kernel elements. + pub dilation: [usize; 2], + /// Controls the connections between input and output channels. + pub groups: usize, + /// Padding configuration. + pub padding: [usize; 2], + /// Padding output configuration. + pub padding_out: [usize; 2], + /// Number of channels. 
+ pub channels: [usize; 2], +} + +impl ModuleDisplay for ConvTranspose2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("channels", &format!("{:?}", &self.channels)) + .add("stride", &format!("{:?}", &self.stride)) + .add("kernel_size", &format!("{:?}", &self.kernel_size)) + .add("dilation", &format!("{:?}", &self.dilation)) + .add("groups", &self.groups) + .add("padding", &format!("{:?}", &self.padding)) + .add("padding_out", &format!("{:?}", &self.padding_out)) + .optional() + } } impl ConvTranspose2dConfig { @@ -92,6 +126,7 @@ impl ConvTranspose2dConfig { groups: self.groups, padding: self.padding, padding_out: self.padding_out, + channels: self.channels, } } } @@ -152,4 +187,15 @@ mod tests { .to_data() .assert_approx_eq(&TensorData::zeros::(conv.weight.shape()), 3); } + + #[test] + fn display() { + let config = ConvTranspose2dConfig::new([5, 2], [5, 5]); + let conv = config.init::(&Default::default()); + + assert_eq!( + format!("{}", conv), + "ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}" + ); + } } diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs index b4bee8d61..d03e95c1f 100644 --- a/crates/burn-core/src/nn/dropout.rs +++ b/crates/burn-core/src/nn/dropout.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::{DisplaySettings, Module, ModuleDisplay}; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::{Distribution, Tensor}; @@ -23,7 +23,8 @@ pub struct DropoutConfig { #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Dropout { - prob: f64, + /// The probability of randomly zeroes some elements of the input tensor during training. + pub prob: f64, } impl DropoutConfig { @@ -62,7 +63,7 @@ impl ModuleDisplay for Dropout { .optional() } - fn custom_content(&self, content: crate::module::Content) -> Option { + fn custom_content(&self, content: Content) -> Option { content.add("prob", &self.prob).optional() } } @@ -99,4 +100,12 @@ mod tests { assert_eq!(tensor.to_data(), output.to_data()); } + + #[test] + fn display() { + let config = DropoutConfig::new(0.5); + let layer = config.init(); + + assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}"); + } } diff --git a/crates/burn-core/src/nn/embedding.rs b/crates/burn-core/src/nn/embedding.rs index 2c6ea7c59..9c9658653 100644 --- a/crates/burn-core/src/nn/embedding.rs +++ b/crates/burn-core/src/nn/embedding.rs @@ -4,6 +4,7 @@ use super::Initializer; use crate::config::Config; use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Int; use crate::tensor::Tensor; @@ -26,12 +27,29 @@ pub struct EmbeddingConfig { /// /// Should be created with [EmbeddingConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct Embedding { /// The learnable weights of the module of shape `[n_embedding, d_model]` initialized /// from a normal distribution `N(0, 1)`. 
pub weight: Param>, } +impl ModuleDisplay for Embedding { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [n_embedding, d_model] = self.weight.shape().dims; + content + .add("n_embedding", &n_embedding) + .add("d_model", &d_model) + .optional() + } +} + impl EmbeddingConfig { /// Initialize a new [embedding](Embedding) module. pub fn init(&self, device: &B::Device) -> Embedding { @@ -100,4 +118,15 @@ mod tests { .to_data() .assert_approx_eq(&TensorData::zeros::(embed.weight.shape()), 3); } + + #[test] + fn display() { + let config = EmbeddingConfig::new(100, 10); + let embed = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", embed), + "Embedding {n_embedding: 100, d_model: 10, params: 1000}" + ); + } } diff --git a/crates/burn-core/src/nn/gelu.rs b/crates/burn-core/src/nn/gelu.rs index 421f83452..f56bc29f8 100644 --- a/crates/burn-core/src/nn/gelu.rs +++ b/crates/burn-core/src/nn/gelu.rs @@ -7,7 +7,7 @@ use crate::tensor::Tensor; /// Applies the Gaussian Error Linear Units function element-wise. /// See also [gelu](burn::tensor::activation::gelu) #[derive(Module, Clone, Debug, Default)] -pub struct Gelu {} +pub struct Gelu; impl Gelu { /// Create the module. @@ -25,3 +25,15 @@ impl Gelu { crate::tensor::activation::gelu(input) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let layer = Gelu::new(); + + assert_eq!(alloc::format!("{}", layer), "Gelu"); + } +} diff --git a/crates/burn-core/src/nn/leaky_relu.rs b/crates/burn-core/src/nn/leaky_relu.rs index ef3fbba8e..339c8e015 100644 --- a/crates/burn-core/src/nn/leaky_relu.rs +++ b/crates/burn-core/src/nn/leaky_relu.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -10,6 +11,7 @@ use crate::tensor::activation::leaky_relu; /// /// Should be created with [LeakyReluConfig](LeakyReluConfig). #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct LeakyRelu { /// The negative slope. pub negative_slope: f64, @@ -30,6 +32,20 @@ impl LeakyReluConfig { } } +impl ModuleDisplay for LeakyRelu { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("negative_slope", &self.negative_slope) + .optional() + } +} + impl LeakyRelu { /// Forward pass for the Leaky ReLu layer. 
/// @@ -92,4 +108,13 @@ mod tests { let actual_output = model.forward(input_data); actual_output.to_data().assert_approx_eq(&expected, 4) } + + #[test] + fn display() { + let config = LeakyReluConfig::new().init(); + assert_eq!( + alloc::format!("{}", config), + "LeakyRelu {negative_slope: 0.01}" + ); + } } diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs index 6fe212b6d..c90074b6d 100644 --- a/crates/burn-core/src/nn/linear.rs +++ b/crates/burn-core/src/nn/linear.rs @@ -1,10 +1,8 @@ use crate as burn; -use crate::module::DisplaySettings; -use crate::module::ModuleDisplay; use crate::config::Config; -use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::tensor::{backend::Backend, Tensor}; use super::Initializer; @@ -93,7 +91,7 @@ impl ModuleDisplay for Linear { .optional() } - fn custom_content(&self, content: crate::module::Content) -> Option { + fn custom_content(&self, content: Content) -> Option { let [d_input, d_output] = self.weight.shape().dims; content .add("d_input", &d_input) @@ -196,4 +194,15 @@ mod tests { assert_eq!(result_1d.into_data(), result_2d.into_data()); } + + #[test] + fn display() { + let config = LinearConfig::new(3, 5); + let linear = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", linear), + "Linear {d_input: 3, d_output: 5, bias: true, params: 20}" + ); + } } diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs index b7570a600..f645c84fd 100644 --- a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::activation::log_sigmoid; use crate::tensor::{backend::Backend, Int, Tensor}; @@ -59,11 +60,30 @@ impl BinaryCrossEntropyLossConfig { /// /// Should be created using [BinaryCrossEntropyLossConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct BinaryCrossEntropyLoss { /// Weights for cross-entropy. pub weights: Option>, - smoothing: Option, - logits: bool, + /// Label smoothing alpha. 
+ pub smoothing: Option, + /// Treat the inputs as logits + pub logits: bool, +} + +impl ModuleDisplay for BinaryCrossEntropyLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("weights", &self.weights) + .add("smoothing", &self.smoothing) + .add("logits", &self.logits) + .optional() + } } impl BinaryCrossEntropyLoss { @@ -368,4 +388,16 @@ mod tests { .init(&device) .forward(logits, targets); } + + #[test] + fn display() { + let config = + BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9])); + let loss = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", loss), + "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}" + ); + } } diff --git a/crates/burn-core/src/nn/loss/cross_entropy.rs b/crates/burn-core/src/nn/loss/cross_entropy.rs index fee040a31..be9484f31 100644 --- a/crates/burn-core/src/nn/loss/cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/cross_entropy.rs @@ -1,8 +1,10 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::activation::log_softmax; use crate::tensor::{backend::Backend, Bool, Int, Tensor}; use crate::{config::Config, module::Module}; +use alloc::string::ToString; use alloc::vec; use alloc::vec::Vec; @@ -29,7 +31,7 @@ pub struct CrossEntropyLossConfig { /// Alpha = 0 would be the same as default. pub smoothing: Option, - /// Create cross-entropy with probabilities as input instead of logits. + /// Create cross-entropy with probabilities as input instead of logits. /// #[config(default = true)] pub logits: bool, @@ -71,12 +73,39 @@ impl CrossEntropyLossConfig { /// /// Should be created using [CrossEntropyLossConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct CrossEntropyLoss { - pad_tokens: Option>, + /// Pad tokens to ignore in the loss calculation. + pub pad_tokens: Option>, /// Weights for cross-entropy. pub weights: Option>, - smoothing: Option, - logits: bool, + /// Label smoothing factor. + pub smoothing: Option, + /// Use logits as input. 
+ pub logits: bool, +} + +impl ModuleDisplay for CrossEntropyLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens { + alloc::format!("Vec<0..{}>", pad_tokens.len()) + } else { + "None".to_string() + }; + + content + .add("pad_tokens", &pad_tokens) + .add("weights", &self.weights) + .add("smoothing", &self.smoothing) + .add("logits", &self.logits) + .optional() + } } impl CrossEntropyLoss { @@ -406,4 +435,17 @@ mod tests { loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); } + + #[test] + fn display() { + let config = CrossEntropyLossConfig::new() + .with_weights(Some(alloc::vec![3., 7., 0.9])) + .with_smoothing(Some(0.5)); + let loss = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", loss), + "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}" + ); + } } diff --git a/crates/burn-core/src/nn/loss/huber.rs b/crates/burn-core/src/nn/loss/huber.rs index 5285e403f..8b227b0a4 100644 --- a/crates/burn-core/src/nn/loss/huber.rs +++ b/crates/burn-core/src/nn/loss/huber.rs @@ -1,9 +1,9 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; use crate::{config::Config, module::Module}; -use core::marker::PhantomData; use super::Reduction; @@ -16,15 +16,11 @@ pub struct HuberLossConfig { impl HuberLossConfig { /// Initialize [Huber loss](HuberLoss). - pub fn init(&self, device: &B::Device) -> HuberLoss { - // device is not needed as of now, but we might want to prepare some data on it - // and its consistent with other loss functions - let _ = device; + pub fn init(&self) -> HuberLoss { self.assertions(); HuberLoss { delta: self.delta, lin_bias: self.delta * self.delta * 0.5, - _backend: PhantomData, } } @@ -52,14 +48,31 @@ impl HuberLossConfig { /// This loss function is less sensitive to outliers than the mean squared error loss. /// /// See also: -#[derive(Module, Debug)] -pub struct HuberLoss { - delta: f32, - lin_bias: f32, // delta * delta * 0.5 precomputed - _backend: PhantomData, +#[derive(Module, Debug, Clone)] +#[module(custom_display)] +pub struct HuberLoss { + /// The bound where the Huber loss function changes from quadratic to linear behaviour. + pub delta: f32, + /// Precomputed value for the linear bias. + pub lin_bias: f32, // delta * delta * 0.5 precomputed } -impl HuberLoss { +impl ModuleDisplay for HuberLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("delta", &self.delta) + .add("lin_bias", &self.lin_bias) + .optional() + } +} + +impl HuberLoss { /// Compute the loss element-wise for the predictions and targets, then reduce /// to a single loss value. 
/// @@ -70,7 +83,7 @@ impl HuberLoss { /// - predictions: \[...dims\] /// - targets: \[...dims\] /// - output: \[1\] - pub fn forward( + pub fn forward( &self, predictions: Tensor, targets: Tensor, @@ -89,7 +102,7 @@ impl HuberLoss { /// - predictions: [...dims] /// - targets: [...dims] /// - output: [...dims] - pub fn forward_no_reduction( + pub fn forward_no_reduction( &self, predictions: Tensor, targets: Tensor, @@ -103,7 +116,10 @@ impl HuberLoss { /// /// - residuals: [...dims] /// - output: [...dims] - pub fn forward_residuals(&self, residuals: Tensor) -> Tensor { + pub fn forward_residuals( + &self, + residuals: Tensor, + ) -> Tensor { let is_large = residuals.clone().abs().greater_elem(self.delta); // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the // `sign()` function, in general, suffers from a jump at 0. @@ -138,7 +154,7 @@ mod tests { let predict = TestTensor::<1>::from_data(predict, &device); let targets = TestTensor::<1>::from_data(targets, &device); - let huber = HuberLossConfig::new(0.5).init(&device); + let huber = HuberLossConfig::new(0.5).init(); let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum); let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto); @@ -166,7 +182,7 @@ mod tests { let predict = TestAutodiffTensor::from_data(predict, &device).require_grad(); let targets = TestAutodiffTensor::from_data(targets, &device); - let loss = HuberLossConfig::new(0.5).init(&device); + let loss = HuberLossConfig::new(0.5).init(); let loss = loss.forward_no_reduction(predict.clone(), targets); let grads = loss.backward(); @@ -175,4 +191,15 @@ mod tests { let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]); grads_predict.to_data().assert_approx_eq(&expected, 3); } + + #[test] + fn display() { + let config = HuberLossConfig::new(0.5); + let loss = config.init(); + + assert_eq!( + alloc::format!("{}", loss), + "HuberLoss {delta: 0.5, lin_bias: 0.125}" + ); + } } diff --git a/crates/burn-core/src/nn/loss/mse.rs b/crates/burn-core/src/nn/loss/mse.rs index b575aef01..56a79f450 100644 --- a/crates/burn-core/src/nn/loss/mse.rs +++ b/crates/burn-core/src/nn/loss/mse.rs @@ -1,26 +1,24 @@ -use crate::nn::loss::reduction::Reduction; -use core::marker::PhantomData; +use crate as burn; +use crate::nn::loss::reduction::Reduction; + +use crate::module::Module; use crate::tensor::{backend::Backend, Tensor}; /// Calculate the mean squared error loss from the input logits and the targets. -#[derive(Clone, Debug)] -pub struct MseLoss { - backend: PhantomData, -} +#[derive(Module, Clone, Debug)] +pub struct MseLoss; -impl Default for MseLoss { +impl Default for MseLoss { fn default() -> Self { Self::new() } } -impl MseLoss { +impl MseLoss { /// Create the criterion. pub fn new() -> Self { - Self { - backend: PhantomData, - } + Self } /// Compute the criterion on the input tensor. @@ -29,7 +27,7 @@ impl MseLoss { /// /// - logits: [batch_size, num_targets] /// - targets: [batch_size, num_targets] - pub fn forward( + pub fn forward( &self, logits: Tensor, targets: Tensor, @@ -43,7 +41,7 @@ impl MseLoss { } /// Compute the criterion on the input tensor without reducing. 
- pub fn forward_no_reduction( + pub fn forward_no_reduction( &self, logits: Tensor, targets: Tensor, @@ -85,4 +83,10 @@ mod tests { let expected = TensorData::from([6.0]); loss_sum.into_data().assert_eq(&expected, false); } + + #[test] + fn display() { + let loss = MseLoss::new(); + assert_eq!(alloc::format!("{}", loss), "MseLoss"); + } } diff --git a/crates/burn-core/src/nn/norm/batch.rs b/crates/burn-core/src/nn/norm/batch.rs index 8519636d3..4a0cb7d2c 100644 --- a/crates/burn-core/src/nn/norm/batch.rs +++ b/crates/burn-core/src/nn/norm/batch.rs @@ -44,8 +44,10 @@ pub struct BatchNorm { pub running_mean: RunningState>, /// The running variance. pub running_var: RunningState>, - momentum: f64, - epsilon: f64, + /// Momentum used to update the metrics. + pub momentum: f64, + /// A value required for numerical stability. + pub epsilon: f64, } impl BatchNormConfig { @@ -424,4 +426,15 @@ mod tests_2d { device, ) } + + #[test] + fn display() { + let batch_norm = + BatchNormConfig::new(3).init::(&Default::default()); + + assert_eq!( + format!("{}", batch_norm), + "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}" + ); + } } diff --git a/crates/burn-core/src/nn/norm/group.rs b/crates/burn-core/src/nn/norm/group.rs index 170a32a27..374156500 100644 --- a/crates/burn-core/src/nn/norm/group.rs +++ b/crates/burn-core/src/nn/norm/group.rs @@ -4,6 +4,7 @@ use crate::nn::Initializer; use crate::config::Config; use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -36,16 +37,37 @@ pub struct GroupNormConfig { /// /// Should be created using [GroupNormConfig](GroupNormConfig). #[derive(Module, Debug)] +#[module(custom_display)] pub struct GroupNorm { /// The learnable weight pub gamma: Option>>, /// The learnable bias pub beta: Option>>, + /// The number of groups to separate the channels into + pub num_groups: usize, + /// The number of channels expected in the input + pub num_channels: usize, + /// A value required for numerical stability + pub epsilon: f64, + /// A boolean value that when set to `true`, this module has learnable + pub affine: bool, +} - pub(crate) num_groups: usize, - pub(crate) num_channels: usize, - pub(crate) epsilon: f64, - pub(crate) affine: bool, +impl ModuleDisplay for GroupNorm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("num_groups", &self.num_groups) + .add("num_channels", &self.num_channels) + .add("epsilon", &self.epsilon) + .add("affine", &self.affine) + .optional() + } } impl GroupNormConfig { @@ -169,6 +191,7 @@ mod tests { use super::*; use crate::tensor::TensorData; use crate::TestBackend; + use alloc::format; #[test] fn group_norm_forward_affine_false() { @@ -292,4 +315,15 @@ mod tests { ]); output.to_data().assert_approx_eq(&expected, 3); } + + #[test] + fn display() { + let config = GroupNormConfig::new(3, 6); + let group_norm = config.init::(&Default::default()); + + assert_eq!( + format!("{}", group_norm), + "GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}" + ); + } } diff --git a/crates/burn-core/src/nn/norm/instance.rs b/crates/burn-core/src/nn/norm/instance.rs index 717c9b74f..3ee5ef4c0 100644 --- a/crates/burn-core/src/nn/norm/instance.rs +++ b/crates/burn-core/src/nn/norm/instance.rs @@ -1,6 +1,7 @@ use crate as 
burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Module, Param}; use crate::nn::norm::group_norm; use crate::nn::Initializer; @@ -25,15 +26,34 @@ pub struct InstanceNormConfig { /// /// Should be created using [InstanceNormConfig](InstanceNormConfig). #[derive(Module, Debug)] +#[module(custom_display)] pub struct InstanceNorm { /// The learnable weight pub gamma: Option>>, /// The learnable bias pub beta: Option>>, + /// The number of channels expected in the input + pub num_channels: usize, + /// A value required for numerical stability + pub epsilon: f64, + /// A boolean value that when set to `true`, this module has learnable + pub affine: bool, +} - num_channels: usize, - epsilon: f64, - affine: bool, +impl ModuleDisplay for InstanceNorm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("num_channels", &self.num_channels) + .add("epsilon", &self.epsilon) + .add("affine", &self.affine) + .optional() + } } impl InstanceNormConfig { @@ -83,6 +103,7 @@ mod tests { use super::*; use crate::tensor::TensorData; use crate::TestBackend; + use alloc::format; #[test] fn instance_norm_forward_affine_false() { @@ -187,4 +208,15 @@ mod tests { ]); output.to_data().assert_approx_eq(&expected, 3); } + + #[test] + fn display() { + let config = InstanceNormConfig::new(6); + let instance_norm = config.init::(&Default::default()); + + assert_eq!( + format!("{}", instance_norm), + "InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}" + ); + } } diff --git a/crates/burn-core/src/nn/norm/layer.rs b/crates/burn-core/src/nn/norm/layer.rs index a856bac8b..ea196906c 100644 --- a/crates/burn-core/src/nn/norm/layer.rs +++ b/crates/burn-core/src/nn/norm/layer.rs @@ -1,5 +1,6 @@ use crate as burn; use crate::config::Config; +use crate::module::Content; use crate::module::DisplaySettings; use crate::module::Module; use crate::module::ModuleDisplay; @@ -33,9 +34,10 @@ pub struct LayerNormConfig { #[module(custom_display)] pub struct LayerNorm { /// The learnable weight. - gamma: Param>, + pub gamma: Param>, /// The learnable bias. - beta: Param>, + pub beta: Param>, + /// A value required for numerical stability. 
epsilon: f64, } @@ -80,7 +82,7 @@ impl ModuleDisplay for LayerNorm { .optional() } - fn custom_content(&self, content: crate::module::Content) -> Option { + fn custom_content(&self, content: Content) -> Option { let [d_model] = self.gamma.shape().dims; content .add("d_model", &d_model) @@ -93,6 +95,7 @@ impl ModuleDisplay for LayerNorm { mod tests { use super::*; use crate::tensor::TensorData; + use alloc::format; #[cfg(feature = "std")] use crate::{TestAutodiffBackend, TestBackend}; @@ -157,4 +160,15 @@ mod tests { let expected = TensorData::zeros::(tensor_2_grad.shape()); tensor_2_grad.to_data().assert_approx_eq(&expected, 3); } + + #[test] + fn display() { + let config = LayerNormConfig::new(6); + let layer_norm = config.init::(&Default::default()); + + assert_eq!( + format!("{}", layer_norm), + "LayerNorm {d_model: 6, epsilon: 0.00001, params: 12}" + ); + } } diff --git a/crates/burn-core/src/nn/norm/rms.rs b/crates/burn-core/src/nn/norm/rms.rs index f9eae5bf2..e054c4f19 100644 --- a/crates/burn-core/src/nn/norm/rms.rs +++ b/crates/burn-core/src/nn/norm/rms.rs @@ -3,6 +3,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::nn::Initializer; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -48,11 +49,12 @@ impl RmsNormConfig { /// /// Should be created using the [RmsNormConfig](RmsNormConfig) configuration. #[derive(Module, Debug)] +#[module(custom_display)] pub struct RmsNorm { /// The learnable parameter to scale the normalized tensor pub gamma: Param>, /// A value required for numerical stability - epsilon: f64, + pub epsilon: f64, } impl RmsNorm { @@ -71,11 +73,28 @@ impl RmsNorm { } } +impl ModuleDisplay for RmsNorm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_model] = self.gamma.shape().dims; + content + .add("d_model", &d_model) + .add("epsilon", &self.epsilon) + .optional() + } +} + #[cfg(test)] mod tests { use super::*; use crate::tensor::TensorData; use crate::TestBackend; + use alloc::format; #[test] fn rms_norm_forward() { @@ -95,4 +114,15 @@ mod tests { ]); output.to_data().assert_approx_eq(&expected, 4); } + + #[test] + fn display() { + let config = RmsNormConfig::new(6); + let layer_norm = config.init::(&Default::default()); + + assert_eq!( + format!("{}", layer_norm), + "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}" + ); + } } diff --git a/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs b/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs index dd2c1d33c..5322fa600 100644 --- a/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs @@ -2,6 +2,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -18,8 +19,22 @@ pub struct AdaptiveAvgPool1dConfig { /// /// Should be created with [AdaptiveAvgPool1dConfig]. #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct AdaptiveAvgPool1d { - output_size: usize, + /// The size of the output. 
+ pub output_size: usize, +} + +impl ModuleDisplay for AdaptiveAvgPool1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content.add("output_size", &self.output_size).optional() + } } impl AdaptiveAvgPool1dConfig { @@ -44,3 +59,19 @@ impl AdaptiveAvgPool1d { adaptive_avg_pool1d(input, self.output_size) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AdaptiveAvgPool1dConfig::new(3); + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "AdaptiveAvgPool1d {output_size: 3}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs b/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs index 8d4d55d42..1f63fb8c9 100644 --- a/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs @@ -2,6 +2,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -18,8 +19,24 @@ pub struct AdaptiveAvgPool2dConfig { /// /// Should be created with [AdaptiveAvgPool2dConfig]. #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct AdaptiveAvgPool2d { - output_size: [usize; 2], + /// The size of the output. + pub output_size: [usize; 2], +} + +impl ModuleDisplay for AdaptiveAvgPool2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let output_size = alloc::format!("{:?}", self.output_size); + + content.add("output_size", &output_size).optional() + } } impl AdaptiveAvgPool2dConfig { @@ -44,3 +61,19 @@ impl AdaptiveAvgPool2d { adaptive_avg_pool2d(input, self.output_size) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AdaptiveAvgPool2dConfig::new([3, 3]); + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "AdaptiveAvgPool2d {output_size: [3, 3]}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/avg_pool1d.rs b/crates/burn-core/src/nn/pool/avg_pool1d.rs index 5787cc5e2..949160fd5 100644 --- a/crates/burn-core/src/nn/pool/avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool1d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig1d; use crate::tensor::backend::Backend; @@ -40,11 +41,33 @@ pub struct AvgPool1dConfig { /// [Issue 636](https://github.com/tracel-ai/burn/issues/636) #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct AvgPool1d { - stride: usize, - kernel_size: usize, - padding: Ignored, - count_include_pad: bool, + /// The stride. + pub stride: usize, + /// The size of the kernel. + pub kernel_size: usize, + /// The padding configuration. + pub padding: Ignored, + /// If the padding is counted in the denominator when computing the average. 
+ pub count_include_pad: bool, +} + +impl ModuleDisplay for AvgPool1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &self.kernel_size) + .add("stride", &self.stride) + .add("padding", &self.padding) + .add("count_include_pad", &self.count_include_pad) + .optional() + } } impl AvgPool1dConfig { @@ -83,3 +106,19 @@ impl AvgPool1d { ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AvgPool1dConfig::new(3); + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "AvgPool1d {kernel_size: 3, stride: 1, padding: Valid, count_include_pad: true}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/avg_pool2d.rs b/crates/burn-core/src/nn/pool/avg_pool2d.rs index 00bf712f8..6c6ffc87e 100644 --- a/crates/burn-core/src/nn/pool/avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool2d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig2d; use crate::tensor::backend::Backend; @@ -39,11 +40,33 @@ pub struct AvgPool2dConfig { /// TODO: Add support for `count_include_pad=False`, see /// [Issue 636](https://github.com/tracel-ai/burn/issues/636) #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct AvgPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: Ignored, - count_include_pad: bool, + /// Stride of the pooling. + pub stride: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 2], + /// Padding configuration. + pub padding: Ignored, + /// If the padding is counted in the denominator when computing the average. + pub count_include_pad: bool, +} + +impl ModuleDisplay for AvgPool2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("padding", &self.padding) + .add("count_include_pad", &self.count_include_pad) + .optional() + } } impl AvgPool2dConfig { @@ -82,3 +105,20 @@ impl AvgPool2d { ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AvgPool2dConfig::new([3, 3]); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "AvgPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, count_include_pad: true}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/max_pool1d.rs b/crates/burn-core/src/nn/pool/max_pool1d.rs index 040a7a102..5be363e90 100644 --- a/crates/burn-core/src/nn/pool/max_pool1d.rs +++ b/crates/burn-core/src/nn/pool/max_pool1d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig1d; use crate::tensor::backend::Backend; @@ -28,11 +29,33 @@ pub struct MaxPool1dConfig { /// /// Should be created with [MaxPool1dConfig](MaxPool1dConfig). #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct MaxPool1d { - stride: usize, - kernel_size: usize, - padding: Ignored, - dilation: usize, + /// The stride. + pub stride: usize, + /// The size of the kernel. 
+ pub kernel_size: usize, + /// The padding configuration. + pub padding: Ignored, + /// The dilation. + pub dilation: usize, +} + +impl ModuleDisplay for MaxPool1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &self.kernel_size) + .add("stride", &self.stride) + .add("padding", &self.padding) + .add("dilation", &self.dilation) + .optional() + } } impl MaxPool1dConfig { @@ -65,3 +88,20 @@ impl MaxPool1d { max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = MaxPool1dConfig::new(3); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "MaxPool1d {kernel_size: 3, stride: 1, padding: Valid, dilation: 1}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/max_pool2d.rs b/crates/burn-core/src/nn/pool/max_pool2d.rs index 552cde9b3..ab9c60d27 100644 --- a/crates/burn-core/src/nn/pool/max_pool2d.rs +++ b/crates/burn-core/src/nn/pool/max_pool2d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig2d; use crate::tensor::backend::Backend; @@ -28,11 +29,33 @@ pub struct MaxPool2dConfig { /// /// Should be created with [MaxPool2dConfig](MaxPool2dConfig). #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct MaxPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: Ignored, - dilation: [usize; 2], + /// The strides. + pub stride: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The padding configuration. + pub padding: Ignored, + /// The dilation. 
+ pub dilation: [usize; 2], +} + +impl ModuleDisplay for MaxPool2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("padding", &self.padding) + .add("dilation", &alloc::format!("{:?}", &self.dilation)) + .optional() + } } impl MaxPool2dConfig { @@ -65,3 +88,20 @@ impl MaxPool2d { max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = MaxPool2dConfig::new([3, 3]); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "MaxPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, dilation: [1, 1]}" + ); + } +} diff --git a/crates/burn-core/src/nn/pos_encoding.rs b/crates/burn-core/src/nn/pos_encoding.rs index 41b7c731e..ff1db2cfc 100644 --- a/crates/burn-core/src/nn/pos_encoding.rs +++ b/crates/burn-core/src/nn/pos_encoding.rs @@ -2,7 +2,8 @@ use alloc::vec::Vec; use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; + use crate::tensor::backend::Backend; use crate::tensor::Tensor; use crate::tensor::TensorData; @@ -40,8 +41,31 @@ pub struct PositionalEncodingConfig { /// /// Should be created using [PositionalEncodingConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct PositionalEncoding { - sinusoids: Tensor, + /// The sinusoids used to add positional information to the input embeddings. + pub sinusoids: Tensor, + /// The maximum sequence size to use. + pub max_sequence_size: usize, + /// Max time scale to use. 
+ pub max_timescale: usize, +} + +impl ModuleDisplay for PositionalEncoding { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [_, _, d_model] = self.sinusoids.shape().dims; + content + .add("d_model", &d_model) + .add("max_sequence_size", &self.max_sequence_size) + .add("max_timescale", &self.max_timescale) + .optional() + } } impl PositionalEncodingConfig { @@ -55,7 +79,11 @@ impl PositionalEncodingConfig { ) .unsqueeze::<3>(); - PositionalEncoding { sinusoids } + PositionalEncoding { + sinusoids, + max_sequence_size: self.max_sequence_size, + max_timescale: self.max_timescale, + } } } @@ -245,4 +273,15 @@ mod tests { let input = Tensor::zeros([1, 6_000, d_model], &device); let _output = pe.forward(input); } + + #[test] + fn display() { + let config = PositionalEncodingConfig::new(4); + let pe = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", pe), + "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}" + ); + } } diff --git a/crates/burn-core/src/nn/prelu.rs b/crates/burn-core/src/nn/prelu.rs index f15c96481..6bb6c32f7 100644 --- a/crates/burn-core/src/nn/prelu.rs +++ b/crates/burn-core/src/nn/prelu.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::nn::Initializer; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -9,11 +9,33 @@ use crate::tensor::Tensor; /// /// Should be created using [PReluConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct PRelu { /// the weights learnt for PReLu. can be of shape \[1\] or \[num_parameters\] in which case it must /// be the same as number of channels in the input tensor pub alpha: Param>, + + /// Alpha value for the PRelu layer + pub alpha_value: f64, } + +impl ModuleDisplay for PRelu { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [num_parameters] = self.alpha.shape().dims; + + content + .add("num_parameters", &num_parameters) + .add("alpha_value", &self.alpha_value) + .optional() + } +} + /// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init). 
#[derive(Config, Debug)] pub struct PReluConfig { @@ -24,12 +46,14 @@ pub struct PReluConfig { #[config(default = "0.25")] pub alpha: f64, } + impl PReluConfig { /// Initialize a new [Parametric Relu](PRelu) Layer pub fn init(&self, device: &B::Device) -> PRelu { PRelu { // alpha is a tensor of length num_parameters alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device), + alpha_value: self.alpha, } } } @@ -47,3 +71,19 @@ impl PRelu { crate::tensor::activation::prelu(input, self.alpha.val()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + + #[test] + fn display() { + let layer = PReluConfig::new().init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", layer), + "PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}" + ); + } +} diff --git a/crates/burn-core/src/nn/relu.rs b/crates/burn-core/src/nn/relu.rs index 262c39313..67ed033b9 100644 --- a/crates/burn-core/src/nn/relu.rs +++ b/crates/burn-core/src/nn/relu.rs @@ -8,7 +8,7 @@ use crate::tensor::Tensor; /// See also [relu](burn::tensor::activation::relu) /// #[derive(Module, Clone, Debug, Default)] -pub struct Relu {} +pub struct Relu; impl Relu { /// Create the module. @@ -25,3 +25,15 @@ impl Relu { crate::tensor::activation::relu(input) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let layer = Relu::new(); + + assert_eq!(alloc::format!("{}", layer), "Relu"); + } +} diff --git a/crates/burn-core/src/nn/rnn/gru.rs b/crates/burn-core/src/nn/rnn/gru.rs index 9201b2675..87b5998fc 100644 --- a/crates/burn-core/src/nn/rnn/gru.rs +++ b/crates/burn-core/src/nn/rnn/gru.rs @@ -2,6 +2,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::nn::rnn::gate_controller; use crate::nn::Initializer; use crate::tensor::activation; @@ -30,11 +31,35 @@ pub struct GruConfig { /// /// Should be created with [GruConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct Gru { - update_gate: GateController, - reset_gate: GateController, - new_gate: GateController, - d_hidden: usize, + /// The update gate controller. + pub update_gate: GateController, + /// The reset gate controller. + pub reset_gate: GateController, + /// The new gate controller. + pub new_gate: GateController, + /// The size of the hidden state. 
+ pub d_hidden: usize, +} + +impl ModuleDisplay for Gru { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_input, _] = self.update_gate.input_transform.weight.shape().dims; + let bias = self.update_gate.input_transform.bias.is_some(); + + content + .add("d_input", &d_input) + .add("d_hidden", &self.d_hidden) + .add("bias", &bias) + .optional() + } } impl GruConfig { @@ -274,4 +299,16 @@ mod tests { assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); } + + #[test] + fn display() { + let config = GruConfig::new(2, 8, true); + + let layer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", layer), + "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}" + ); + } } diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 63ad33937..c1d1b23d4 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -2,6 +2,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::nn::rnn::gate_controller::GateController; use crate::nn::Initializer; use crate::tensor::activation; @@ -43,6 +44,7 @@ pub struct LstmConfig { /// /// Should be created with [LstmConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct Lstm { /// The input gate regulates which information to update and store in the cell state at each time step. pub input_gate: GateController, @@ -52,7 +54,27 @@ pub struct Lstm { pub output_gate: GateController, /// The cell gate is used to compute the cell state that stores and carries information through time. pub cell_gate: GateController, - d_hidden: usize, + /// The hidden state of the LSTM. + pub d_hidden: usize, +} + +impl ModuleDisplay for Lstm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_input, _] = self.input_gate.input_transform.weight.shape().dims; + let bias = self.input_gate.input_transform.bias.is_some(); + + content + .add("d_input", &d_input) + .add("d_hidden", &self.d_hidden) + .add("bias", &bias) + .optional() + } } impl LstmConfig { @@ -195,12 +217,33 @@ pub struct BiLstmConfig { /// /// Should be created with [BiLstmConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct BiLstm { /// LSTM for the forward direction. pub forward: Lstm, /// LSTM for the reverse direction. pub reverse: Lstm, - d_hidden: usize, + /// The size of the hidden state. 
+ pub d_hidden: usize, +} + +impl ModuleDisplay for BiLstm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_input, _] = self.forward.input_gate.input_transform.weight.shape().dims; + let bias = self.forward.input_gate.input_transform.bias.is_some(); + + content + .add("d_input", &d_input) + .add("d_hidden", &self.d_hidden) + .add("bias", &bias) + .optional() + } } impl BiLstmConfig { @@ -693,4 +736,28 @@ mod tests { .to_data() .assert_approx_eq(&expected_cn_without_init_state, 3); } + + #[test] + fn display_lstm() { + let config = LstmConfig::new(2, 3, true); + + let layer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", layer), + "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}" + ); + } + + #[test] + fn display_bilstm() { + let config = BiLstmConfig::new(2, 3, true); + + let layer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", layer), + "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}" + ); + } } diff --git a/crates/burn-core/src/nn/rope_encoding.rs b/crates/burn-core/src/nn/rope_encoding.rs index 4351bcbe2..2f9401471 100644 --- a/crates/burn-core/src/nn/rope_encoding.rs +++ b/crates/burn-core/src/nn/rope_encoding.rs @@ -1,6 +1,6 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Int; use crate::tensor::Tensor; @@ -74,7 +74,11 @@ impl RotaryEncodingConfig { .repeat(2, 2) .reshape([self.max_sequence_length, self.d_model, 2]); - RotaryEncoding { freq_complex } + RotaryEncoding { + freq_complex, + max_sequence_length: self.max_sequence_length, + theta: self.theta, + } } } @@ -87,9 +91,31 @@ impl RotaryEncodingConfig { /// /// Should be created using [RotaryEncodingConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct RotaryEncoding { /// Frequency Tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components - freq_complex: Tensor, + pub freq_complex: Tensor, + /// Maximum sequence length of input + pub max_sequence_length: usize, + /// Scaling factor for frequency computation. 
+ pub theta: f32, +} + +impl ModuleDisplay for RotaryEncoding { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [_, _, d_model] = self.freq_complex.shape().dims; + content + .add("d_model", &d_model) + .add("max_sequence_length", &self.max_sequence_length) + .add("theta", &self.theta) + .optional() + } } #[allow(clippy::single_range_in_vec_init)] @@ -238,4 +264,15 @@ mod tests { let input = Tensor::zeros([1, 5, d_model], &device); let _output = pe.forward(input); } + + #[test] + fn display() { + let config = RotaryEncodingConfig::new(10, 4); + let pe = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", pe), + "RotaryEncoding {d_model: 2, max_sequence_length: 10, theta: 10000}" + ); + } } diff --git a/crates/burn-core/src/nn/swiglu.rs b/crates/burn-core/src/nn/swiglu.rs index 3dacbae68..2227db545 100644 --- a/crates/burn-core/src/nn/swiglu.rs +++ b/crates/burn-core/src/nn/swiglu.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::tensor::activation::silu; use crate::tensor::{backend::Backend, Tensor}; @@ -31,6 +31,7 @@ pub struct SwiGluConfig { /// /// Should be created with [SwiGluConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct SwiGlu { /// The inner linear layer for Swish activation function /// with `d_input` input features and `d_output` output features. @@ -40,6 +41,23 @@ pub struct SwiGlu { pub linear_outer: Linear, } +impl ModuleDisplay for SwiGlu { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_input, d_output] = self.linear_inner.weight.shape().dims; + content + .add("d_input", &d_input) + .add("d_output", &d_output) + .add("bias", &self.linear_inner.bias.is_some()) + .optional() + } +} + impl SwiGluConfig { /// Initialize a new [SwiGLU](SwiGlu) activation layer. pub fn init(&self, device: &B::Device) -> SwiGlu { @@ -61,7 +79,7 @@ impl SwiGlu { /// /// # Shapes /// - /// - input: `[batch_size, seq_length, d_input]` + /// - input: `[batch_size, seq_length, d_input]` /// - output: `[batch_size, seq_length, d_output]` pub fn forward(&self, input: Tensor) -> Tensor { let x = self.linear_inner.forward(input.clone()); @@ -112,4 +130,15 @@ mod tests { .to_data() .assert_approx_eq(&expected_output.to_data(), 4); } + + #[test] + fn display() { + let config = SwiGluConfig::new(3, 5); + let swiglu = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", swiglu), + "SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}" + ); + } } diff --git a/crates/burn-core/src/nn/tanh.rs b/crates/burn-core/src/nn/tanh.rs index 293da36ce..322ac68bd 100644 --- a/crates/burn-core/src/nn/tanh.rs +++ b/crates/burn-core/src/nn/tanh.rs @@ -7,7 +7,7 @@ use crate::tensor::Tensor; /// Applies the tanh activation function element-wise /// See also [tanh](burn::tensor::activation::tanh) #[derive(Module, Clone, Debug, Default)] -pub struct Tanh {} +pub struct Tanh; impl Tanh { /// Create the module. 
@@ -24,3 +24,15 @@ impl Tanh { crate::tensor::activation::tanh(input) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let layer = Tanh::new(); + + assert_eq!(alloc::format!("{}", layer), "Tanh"); + } +} diff --git a/crates/burn-core/src/nn/transformer/decoder.rs b/crates/burn-core/src/nn/transformer/decoder.rs index 85fc50159..7784972c7 100644 --- a/crates/burn-core/src/nn/transformer/decoder.rs +++ b/crates/burn-core/src/nn/transformer/decoder.rs @@ -1,15 +1,15 @@ -use crate::tensor::Bool; use alloc::vec::Vec; +use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; + +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; +use crate::tensor::Bool; use crate::{ self as burn, nn::{attention::MhaCache, cache::TensorCache, Initializer}, }; - -use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ config::Config, - module::Module, nn::{ attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, Dropout, DropoutConfig, LayerNorm, LayerNormConfig, @@ -57,8 +57,51 @@ pub struct TransformerDecoderConfig { /// /// Should be created using [TransformerDecoderConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct TransformerDecoder { - layers: Vec>, + /// Transformer decoder layers. + pub layers: Vec>, + + /// The size of the model. + pub d_model: usize, + + /// The size of the position-wise feed-forward network. + pub d_ff: usize, + + /// The number of attention heads. + pub n_heads: usize, + + /// The number of layers. + pub n_layers: usize, + + /// The dropout rate. Default: 0.1 + pub dropout: f64, + + /// Layer norm will be applied first instead of after the other modules. + pub norm_first: bool, + + /// Use "quiet softmax" instead of regular softmax. 
+ pub quiet_softmax: bool, +} + +impl ModuleDisplay for TransformerDecoder { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("d_model", &self.d_model) + .add("d_ff", &self.d_ff) + .add("n_heads", &self.n_heads) + .add("n_layers", &self.n_layers) + .add("dropout", &self.dropout) + .add("norm_first", &self.norm_first) + .add("quiet_softmax", &self.quiet_softmax) + .optional() + } } impl TransformerDecoderConfig { @@ -68,7 +111,16 @@ impl TransformerDecoderConfig { .map(|_| TransformerDecoderLayer::new(self, device)) .collect::>(); - TransformerDecoder { layers } + TransformerDecoder { + layers, + d_model: self.d_model, + d_ff: self.d_ff, + n_heads: self.n_heads, + n_layers: self.n_layers, + dropout: self.dropout, + norm_first: self.norm_first, + quiet_softmax: self.quiet_softmax, + } } } @@ -473,4 +525,16 @@ mod tests { .into_data() .assert_approx_eq(&output_2.into_data(), 3); } + + #[test] + fn display() { + let config = TransformerDecoderConfig::new(2, 4, 2, 3); + let transformer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", transformer), + "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \ + dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}" + ); + } } diff --git a/crates/burn-core/src/nn/transformer/encoder.rs b/crates/burn-core/src/nn/transformer/encoder.rs index 0eb226a3b..6aea721d3 100644 --- a/crates/burn-core/src/nn/transformer/encoder.rs +++ b/crates/burn-core/src/nn/transformer/encoder.rs @@ -1,15 +1,14 @@ use crate::tensor::Bool; use alloc::vec::Vec; +use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::{ self as burn, nn::{attention::MhaCache, cache::TensorCache, Initializer}, }; - -use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ config::Config, - module::Module, nn::{ attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, Dropout, DropoutConfig, LayerNorm, LayerNormConfig, @@ -57,8 +56,51 @@ pub struct TransformerEncoderConfig { /// /// Should be created using [TransformerEncoderConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct TransformerEncoder { - layers: Vec>, + /// The transformer encoder layers. + pub layers: Vec>, + + /// The size of the model. + pub d_model: usize, + + /// The size of the position-wise feed-forward network. + pub d_ff: usize, + + /// The number of attention heads. + pub n_heads: usize, + + /// The number of layers. + pub n_layers: usize, + + /// The dropout rate. Default: 0.1 + pub dropout: f64, + + /// Layer norm will be applied first instead of after the other modules. + pub norm_first: bool, + + /// Use "quiet softmax" instead of regular softmax. + pub quiet_softmax: bool, +} + +impl ModuleDisplay for TransformerEncoder { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("d_model", &self.d_model) + .add("d_ff", &self.d_ff) + .add("n_heads", &self.n_heads) + .add("n_layers", &self.n_layers) + .add("dropout", &self.dropout) + .add("norm_first", &self.norm_first) + .add("quiet_softmax", &self.quiet_softmax) + .optional() + } } /// [Transformer Encoder](TransformerEncoder) forward pass input argument. 
@@ -98,7 +140,16 @@ impl TransformerEncoderConfig { .map(|_| TransformerEncoderLayer::new(self, device)) .collect::>(); - TransformerEncoder { layers } + TransformerEncoder { + layers, + d_model: self.d_model, + d_ff: self.d_ff, + n_heads: self.n_heads, + n_layers: self.n_layers, + dropout: self.dropout, + norm_first: self.norm_first, + quiet_softmax: self.quiet_softmax, + } } } @@ -392,4 +443,16 @@ mod tests { .into_data() .assert_approx_eq(&output_2.into_data(), 3); } + + #[test] + fn display() { + let config = TransformerEncoderConfig::new(2, 4, 2, 3); + let transformer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", transformer), + "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \ + n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}" + ); + } } diff --git a/crates/burn-core/src/nn/transformer/pwff.rs b/crates/burn-core/src/nn/transformer/pwff.rs index bd168b643..1c7af0149 100644 --- a/crates/burn-core/src/nn/transformer/pwff.rs +++ b/crates/burn-core/src/nn/transformer/pwff.rs @@ -1,9 +1,9 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::nn::Initializer; use crate::{ config::Config, - module::Module, nn::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig}, tensor::{backend::Backend, Tensor}, }; @@ -36,11 +36,34 @@ pub struct PositionWiseFeedForwardConfig { /// /// Should be created using [PositionWiseFeedForwardConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct PositionWiseFeedForward { - linear_inner: Linear, - linear_outer: Linear, - dropout: Dropout, - gelu: Gelu, + /// Linear layer with `d_model` input features and `d_ff` output features. + pub linear_inner: Linear, + /// Linear layer with `d_ff` input features and `d_model` output features. + pub linear_outer: Linear, + /// Dropout layer. + pub dropout: Dropout, + /// GELU activation function. + pub gelu: Gelu, +} + +impl ModuleDisplay for PositionWiseFeedForward { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_model, dff] = self.linear_inner.weight.shape().dims; + + content + .add("d_model", &d_model) + .add("d_ff", &dff) + .add("prob", &self.dropout.prob) + .optional() + } } impl PositionWiseFeedForwardConfig { @@ -74,3 +97,20 @@ impl PositionWiseFeedForward { self.linear_outer.forward(x) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + + #[test] + fn display() { + let config = PositionWiseFeedForwardConfig::new(2, 4); + let pwff = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", pwff), + "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}" + ); + } +} diff --git a/crates/burn-core/src/nn/unfold.rs b/crates/burn-core/src/nn/unfold.rs index c958883e2..41d2cedbb 100644 --- a/crates/burn-core/src/nn/unfold.rs +++ b/crates/burn-core/src/nn/unfold.rs @@ -1,7 +1,8 @@ use crate as burn; use crate::config::Config; -use crate::module::{Ignored, Module}; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; + use burn_tensor::backend::Backend; use burn_tensor::module::unfold4d; use burn_tensor::ops::UnfoldOptions; @@ -27,15 +28,43 @@ pub struct Unfold4dConfig { /// /// Should be created with [Unfold4dConfig]. #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct Unfold4d { - config: Ignored, + /// The size of the kernel. 
+ pub kernel_size: [usize; 2], + /// The stride of the convolution. + pub stride: [usize; 2], + /// Spacing between kernel elements. + pub dilation: [usize; 2], + /// The padding configuration. + pub padding: [usize; 2], +} + +impl ModuleDisplay for Unfold4d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("dilation", &alloc::format!("{:?}", &self.dilation)) + .add("padding", &alloc::format!("{:?}", &self.padding)) + .optional() + } } impl Unfold4dConfig { /// Initializes a new [Unfold4d] module. pub fn init(&self) -> Unfold4d { Unfold4d { - config: Ignored(self.clone()), + kernel_size: self.kernel_size, + stride: self.stride, + dilation: self.dilation, + padding: self.padding, } } } @@ -52,12 +81,24 @@ impl Unfold4d { pub fn forward(&self, input: Tensor) -> Tensor { unfold4d( input, - self.config.kernel_size, - UnfoldOptions::new( - self.config.stride, - self.config.padding, - self.config.dilation, - ), + self.kernel_size, + UnfoldOptions::new(self.stride, self.padding, self.dilation), ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = Unfold4dConfig::new([3, 3]); + let unfold = config.init(); + + assert_eq!( + alloc::format!("{}", unfold), + "Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}" + ); + } +} diff --git a/crates/burn-import/pytorch-tests/tests/linear/mod.rs b/crates/burn-import/pytorch-tests/tests/linear/mod.rs index 4244a35a4..3ba09fa0e 100644 --- a/crates/burn-import/pytorch-tests/tests/linear/mod.rs +++ b/crates/burn-import/pytorch-tests/tests/linear/mod.rs @@ -16,7 +16,7 @@ impl Net { pub fn init(device: &B::Device) -> Self { let fc1 = LinearConfig::new(2, 3).init(device); let fc2 = LinearConfig::new(3, 4).init(device); - let relu = Relu::default(); + let relu = Relu; Self { fc1, fc2, relu } } diff --git a/crates/burn-import/src/burn/node/conv_transpose_2d.rs b/crates/burn-import/src/burn/node/conv_transpose_2d.rs index 88966ddae..30f26ef56 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_2d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_2d.rs @@ -101,6 +101,7 @@ impl NodeCodegen for ConvTranspose2dNode { groups: ConstantRecord::new(), padding: [ConstantRecord::new(); 2], padding_out: [ConstantRecord::new(); 2], + channels: [ConstantRecord::new(); 2], }; let item = Record::into_item::(record); diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs index 57c948a3a..59f9baed2 100644 --- a/crates/burn-import/src/burn/node/prelu.rs +++ b/crates/burn-import/src/burn/node/prelu.rs @@ -1,7 +1,7 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type}; use burn::{ - module::{Param, ParamId}, + module::{ConstantRecord, Param, ParamId}, nn::{PReluConfig, PReluRecord}, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, @@ -70,6 +70,7 @@ impl NodeCodegen for PReluNode { ParamId::new(), Tensor::from_data(self.alpha.clone().convert::(), &device), ), + alpha_value: ConstantRecord, }; let item = Record::into_item::(record); diff --git a/crates/burn-train/src/learner/train_val.rs b/crates/burn-train/src/learner/train_val.rs index cbca2895a..ed01271fc 100644 --- 
a/crates/burn-train/src/learner/train_val.rs +++ b/crates/burn-train/src/learner/train_val.rs @@ -122,7 +122,7 @@ impl Learner { >::InnerModule: ValidStep, LC::EventProcessor: EventProcessor, { - log::info!("Fitting {}", self.model.to_string()); + log::info!("Fitting the model:\n {}", self.model.to_string()); // The reference model is always on the first device provided. if let Some(device) = self.devices.first() { self.model = self.model.fork(device);
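
Note (illustrative, not part of the diff): with the `ModuleDisplay` implementations added above, printing any of these modules now yields the summaries asserted in the new `display` tests. Below is a minimal sketch of how that looks from user code; it assumes the `burn` crate with the `ndarray` feature enabled (the `NdArray` backend is only a stand-in here, any `Backend` works), and mirrors the config values used in the encoder `display` test.

```rust
use burn::backend::NdArray;
use burn::nn::transformer::TransformerEncoderConfig;

fn main() {
    let device = Default::default();
    // Build a small encoder; values mirror the `display` test in encoder.rs above.
    let encoder = TransformerEncoderConfig::new(2, 4, 2, 3).init::<NdArray>(&device);

    // The derived `Display` impl routes through `ModuleDisplay`, so this prints the
    // hyperparameters and parameter count, along the lines of:
    // "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, ...}"
    println!("{encoder}");
}
```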