Print module - implement module display for remaining modules (part2) (#1933)

Dilshod Tadjibaev 2024-06-28 07:37:40 -05:00 committed by GitHub
parent 1ae1c03b2d
commit 98a58c867d
41 changed files with 1340 additions and 153 deletions

View File

@ -1,10 +1,10 @@
use crate as burn;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::cache::TensorCache;
use crate::nn::Initializer;
use crate::{
config::Config,
module::Module,
nn,
tensor::{activation, backend::Backend, Bool, Tensor},
};
@ -53,17 +53,49 @@ pub struct MultiHeadAttentionConfig {
///
/// Should be created with [MultiHeadAttentionConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct MultiHeadAttention<B: Backend> {
query: nn::Linear<B>,
key: nn::Linear<B>,
value: nn::Linear<B>,
output: nn::Linear<B>,
dropout: nn::Dropout,
activation: nn::Gelu,
n_heads: usize,
d_k: usize,
min_float: f64,
quiet_softmax: bool,
/// Linear layer to transform the input features into the query space.
pub query: nn::Linear<B>,
/// Linear layer to transform the input features into the key space.
pub key: nn::Linear<B>,
/// Linear layer to transform the input features into the value space.
pub value: nn::Linear<B>,
/// Linear layer to transform the output features back to the original space.
pub output: nn::Linear<B>,
/// Dropout layer.
pub dropout: nn::Dropout,
/// Activation function.
pub activation: nn::Gelu,
/// The size of each linear layer.
pub d_model: usize,
/// The number of heads.
pub n_heads: usize,
/// Size of the key and query vectors.
pub d_k: usize,
/// Minimum value a float can take; used to mask attention scores.
pub min_float: f64,
/// Use "quiet softmax" instead of regular softmax.
pub quiet_softmax: bool,
}
impl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("d_model", &self.d_model)
.add("n_heads", &self.n_heads)
.add("d_k", &self.d_k)
.add("dropout", &self.dropout.prob)
.add("min_float", &self.min_float)
.add("quiet_softmax", &self.quiet_softmax)
.optional()
}
}
/// [Multihead attention](MultiHeadAttention) forward pass input argument.
@ -99,6 +131,7 @@ impl MultiHeadAttentionConfig {
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
d_model: self.d_model,
}
}
}
@ -478,4 +511,16 @@ mod tests {
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
}
#[test]
fn display() {
let config = MultiHeadAttentionConfig::new(2, 4);
let mha = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", mha),
"MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \
dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}"
);
}
}
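
The pattern above repeats throughout this commit: derive `Module` with `#[module(custom_display)]`, then implement `ModuleDisplay` by hand. As a rough sketch (not part of the diff), a downstream crate could apply the same pattern to its own module; the `MyBlock` type and its fields are hypothetical:

use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::nn::Linear;
use burn::tensor::backend::Backend;

/// A hypothetical downstream module opting into the custom display pattern.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct MyBlock<B: Backend> {
    /// Inner projection layer.
    pub linear: Linear<B>,
    /// Hidden size reported in the display output.
    pub d_hidden: usize,
}

impl<B: Backend> ModuleDisplay for MyBlock<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        // Keep all attributes on a single line, as the built-in modules do.
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // Report only the hidden size in the formatted output.
        content.add("d_hidden", &self.d_hidden).optional()
    }
}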

View File

@ -50,11 +50,16 @@ pub struct Conv1d<B: Backend> {
pub weight: Param<Tensor<B, 3>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
stride: usize,
kernel_size: usize,
dilation: usize,
groups: usize,
padding: Ignored<PaddingConfig1d>,
/// Stride of the convolution.
pub stride: usize,
/// Size of the kernel.
pub kernel_size: usize,
/// Spacing between kernel elements.
pub dilation: usize,
/// Controls the connections between input and output channels.
pub groups: usize,
/// Padding configuration.
pub padding: Ignored<PaddingConfig1d>,
}
impl<B: Backend> ModuleDisplay for Conv1d<B> {
@ -169,4 +174,15 @@ mod tests {
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
}
#[test]
fn display() {
let config = Conv1dConfig::new(5, 5, 5);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", conv),
"Conv1d {stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
);
}
}

View File

@ -52,11 +52,16 @@ pub struct Conv2d<B: Backend> {
pub weight: Param<Tensor<B, 4>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
stride: [usize; 2],
kernel_size: [usize; 2],
dilation: [usize; 2],
groups: usize,
padding: Ignored<PaddingConfig2d>,
/// Stride of the convolution.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
pub groups: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig2d>,
}
impl Conv2dConfig {
@ -214,4 +219,15 @@ mod tests {
assert_eq!(config.initializer, init);
}
#[test]
fn display() {
let config = Conv2dConfig::new([5, 1], [5, 5]);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", conv),
"Conv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}"
);
}
}

View File

@ -1,7 +1,12 @@
use alloc::format;
use crate as burn;
use crate::config::Config;
use crate::module::Content;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::conv::checks;
use crate::nn::Initializer;
@ -45,17 +50,46 @@ pub struct ConvTranspose1dConfig {
/// Applies a 1D transposed convolution over input tensors.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct ConvTranspose1d<B: Backend> {
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size]`
pub weight: Param<Tensor<B, 3>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
stride: usize,
kernel_size: usize,
dilation: usize,
groups: usize,
padding: usize,
padding_out: usize,
/// Stride of the convolution.
pub stride: usize,
/// Size of the kernel.
pub kernel_size: usize,
/// Spacing between kernel elements.
pub dilation: usize,
/// Controls the connections between input and output channels.
pub groups: usize,
/// The padding configuration.
pub padding: usize,
/// The padding output configuration.
pub padding_out: usize,
/// The number of channels.
pub channels: [usize; 2],
}
impl<B: Backend> ModuleDisplay for ConvTranspose1d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("channels", &format!("{:?}", &self.channels))
.add("stride", &self.stride)
.add("kernel_size", &self.kernel_size)
.add("dilation", &self.dilation)
.add("groups", &self.groups)
.add("padding", &self.padding)
.add("padding_out", &self.padding_out)
.optional()
}
}
impl ConvTranspose1dConfig {
@ -91,6 +125,7 @@ impl ConvTranspose1dConfig {
groups: self.groups,
padding: self.padding,
padding_out: self.padding_out,
channels: self.channels,
}
}
}
@ -150,4 +185,15 @@ mod tests {
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
}
#[test]
fn display() {
let config = ConvTranspose1dConfig::new([5, 2], 5);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{}", conv),
"ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}"
);
}
}

View File

@ -1,7 +1,12 @@
use alloc::format;
use crate as burn;
use crate::config::Config;
use crate::module::Content;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::conv::checks;
use crate::nn::Initializer;
@ -45,17 +50,46 @@ pub struct ConvTranspose2dConfig {
/// Applies a 2D transposed convolution over input tensors.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct ConvTranspose2d<B: Backend> {
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`
pub weight: Param<Tensor<B, 4>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
stride: [usize; 2],
kernel_size: [usize; 2],
dilation: [usize; 2],
groups: usize,
padding: [usize; 2],
padding_out: [usize; 2],
/// Stride of the convolution.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
pub groups: usize,
/// Padding configuration.
pub padding: [usize; 2],
/// Padding output configuration.
pub padding_out: [usize; 2],
/// Number of channels.
pub channels: [usize; 2],
}
impl<B: Backend> ModuleDisplay for ConvTranspose2d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("channels", &format!("{:?}", &self.channels))
.add("stride", &format!("{:?}", &self.stride))
.add("kernel_size", &format!("{:?}", &self.kernel_size))
.add("dilation", &format!("{:?}", &self.dilation))
.add("groups", &self.groups)
.add("padding", &format!("{:?}", &self.padding))
.add("padding_out", &format!("{:?}", &self.padding_out))
.optional()
}
}
impl ConvTranspose2dConfig {
@ -92,6 +126,7 @@ impl ConvTranspose2dConfig {
groups: self.groups,
padding: self.padding,
padding_out: self.padding_out,
channels: self.channels,
}
}
}
@ -152,4 +187,15 @@ mod tests {
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
}
#[test]
fn display() {
let config = ConvTranspose2dConfig::new([5, 2], [5, 5]);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{}", conv),
"ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}"
);
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::{DisplaySettings, Module, ModuleDisplay};
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};
@ -23,7 +23,8 @@ pub struct DropoutConfig {
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Dropout {
prob: f64,
/// The probability of randomly zeroing some elements of the input tensor during training.
pub prob: f64,
}
impl DropoutConfig {
@ -62,7 +63,7 @@ impl ModuleDisplay for Dropout {
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("prob", &self.prob).optional()
}
}
@ -99,4 +100,12 @@ mod tests {
assert_eq!(tensor.to_data(), output.to_data());
}
#[test]
fn display() {
let config = DropoutConfig::new(0.5);
let layer = config.init();
assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}");
}
}

View File

@ -4,6 +4,7 @@ use super::Initializer;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Int;
use crate::tensor::Tensor;
@ -26,12 +27,29 @@ pub struct EmbeddingConfig {
///
/// Should be created with [EmbeddingConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Embedding<B: Backend> {
/// The learnable weights of the module of shape `[n_embedding, d_model]` initialized
/// from a normal distribution `N(0, 1)`.
pub weight: Param<Tensor<B, 2>>,
}
impl<B: Backend> ModuleDisplay for Embedding<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [n_embedding, d_model] = self.weight.shape().dims;
content
.add("n_embedding", &n_embedding)
.add("d_model", &d_model)
.optional()
}
}
impl EmbeddingConfig {
/// Initialize a new [embedding](Embedding) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Embedding<B> {
@ -100,4 +118,15 @@ mod tests {
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>(embed.weight.shape()), 3);
}
#[test]
fn display() {
let config = EmbeddingConfig::new(100, 10);
let embed = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", embed),
"Embedding {n_embedding: 100, d_model: 10, params: 1000}"
);
}
}

View File

@ -7,7 +7,7 @@ use crate::tensor::Tensor;
/// Applies the Gaussian Error Linear Units function element-wise.
/// See also [gelu](burn::tensor::activation::gelu)
#[derive(Module, Clone, Debug, Default)]
pub struct Gelu {}
pub struct Gelu;
impl Gelu {
/// Create the module.
@ -25,3 +25,15 @@ impl Gelu {
crate::tensor::activation::gelu(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = Gelu::new();
assert_eq!(alloc::format!("{}", layer), "Gelu");
}
}

View File

@ -1,6 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -10,6 +11,7 @@ use crate::tensor::activation::leaky_relu;
///
/// Should be created with [LeakyReluConfig](LeakyReluConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct LeakyRelu {
/// The negative slope.
pub negative_slope: f64,
@ -30,6 +32,20 @@ impl LeakyReluConfig {
}
}
impl ModuleDisplay for LeakyRelu {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("negative_slope", &self.negative_slope)
.optional()
}
}
impl LeakyRelu {
/// Forward pass for the Leaky ReLu layer.
///
@ -92,4 +108,13 @@ mod tests {
let actual_output = model.forward(input_data);
actual_output.to_data().assert_approx_eq(&expected, 4)
}
#[test]
fn display() {
let config = LeakyReluConfig::new().init();
assert_eq!(
alloc::format!("{}", config),
"LeakyRelu {negative_slope: 0.01}"
);
}
}

View File

@ -1,10 +1,8 @@
use crate as burn;
use crate::module::DisplaySettings;
use crate::module::ModuleDisplay;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::{backend::Backend, Tensor};
use super::Initializer;
@ -93,7 +91,7 @@ impl<B: Backend> ModuleDisplay for Linear<B> {
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, d_output] = self.weight.shape().dims;
content
.add("d_input", &d_input)
@ -196,4 +194,15 @@ mod tests {
assert_eq!(result_1d.into_data(), result_2d.into_data());
}
#[test]
fn display() {
let config = LinearConfig::new(3, 5);
let linear = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", linear),
"Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
);
}
}

View File

@ -1,4 +1,5 @@
use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::activation::log_sigmoid;
use crate::tensor::{backend::Backend, Int, Tensor};
@ -59,11 +60,30 @@ impl BinaryCrossEntropyLossConfig {
///
/// Should be created using [BinaryCrossEntropyLossConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BinaryCrossEntropyLoss<B: Backend> {
/// Weights for cross-entropy.
pub weights: Option<Tensor<B, 1>>,
smoothing: Option<f32>,
logits: bool,
/// Label smoothing alpha.
pub smoothing: Option<f32>,
/// Treat the inputs as logits
pub logits: bool,
}
impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("weights", &self.weights)
.add("smoothing", &self.smoothing)
.add("logits", &self.logits)
.optional()
}
}
impl<B: Backend> BinaryCrossEntropyLoss<B> {
@ -368,4 +388,16 @@ mod tests {
.init(&device)
.forward(logits, targets);
}
#[test]
fn display() {
let config =
BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
let loss = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", loss),
"BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
);
}
}

View File

@ -1,8 +1,10 @@
use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::activation::log_softmax;
use crate::tensor::{backend::Backend, Bool, Int, Tensor};
use crate::{config::Config, module::Module};
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;
@ -71,12 +73,39 @@ impl CrossEntropyLossConfig {
///
/// Should be created using [CrossEntropyLossConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
pad_tokens: Option<Vec<usize>>,
/// Pad tokens to ignore in the loss calculation.
pub pad_tokens: Option<Vec<usize>>,
/// Weights for cross-entropy.
pub weights: Option<Tensor<B, 1>>,
smoothing: Option<f32>,
logits: bool,
/// Label smoothing factor.
pub smoothing: Option<f32>,
/// Use logits as input.
pub logits: bool,
}
impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
alloc::format!("Vec<0..{}>", pad_tokens.len())
} else {
"None".to_string()
};
content
.add("pad_tokens", &pad_tokens)
.add("weights", &self.weights)
.add("smoothing", &self.smoothing)
.add("logits", &self.logits)
.optional()
}
}
impl<B: Backend> CrossEntropyLoss<B> {
@ -406,4 +435,17 @@ mod tests {
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
}
#[test]
fn display() {
let config = CrossEntropyLossConfig::new()
.with_weights(Some(alloc::vec![3., 7., 0.9]))
.with_smoothing(Some(0.5));
let loss = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", loss),
"CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
);
}
}

View File

@ -1,9 +1,9 @@
use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use crate::{config::Config, module::Module};
use core::marker::PhantomData;
use super::Reduction;
@ -16,15 +16,11 @@ pub struct HuberLossConfig {
impl HuberLossConfig {
/// Initialize [Huber loss](HuberLoss).
pub fn init<B: Backend>(&self, device: &B::Device) -> HuberLoss<B> {
// device is not needed as of now, but we might want to prepare some data on it
// and its consistent with other loss functions
let _ = device;
pub fn init(&self) -> HuberLoss {
self.assertions();
HuberLoss {
delta: self.delta,
lin_bias: self.delta * self.delta * 0.5,
_backend: PhantomData,
}
}
@ -52,14 +48,31 @@ impl HuberLossConfig {
/// This loss function is less sensitive to outliers than the mean squared error loss.
///
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
#[derive(Module, Debug)]
pub struct HuberLoss<B: Backend> {
delta: f32,
lin_bias: f32, // delta * delta * 0.5 precomputed
_backend: PhantomData<B>,
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct HuberLoss {
/// The bound where the Huber loss function changes from quadratic to linear behaviour.
pub delta: f32,
/// Precomputed value for the linear bias.
pub lin_bias: f32, // delta * delta * 0.5 precomputed
}
impl<B: Backend> HuberLoss<B> {
impl ModuleDisplay for HuberLoss {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("delta", &self.delta)
.add("lin_bias", &self.lin_bias)
.optional()
}
}
impl HuberLoss {
/// Compute the loss element-wise for the predictions and targets, then reduce
/// to a single loss value.
///
@ -70,7 +83,7 @@ impl<B: Backend> HuberLoss<B> {
/// - predictions: \[...dims\]
/// - targets: \[...dims\]
/// - output: \[1\]
pub fn forward<const D: usize>(
pub fn forward<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
@ -89,7 +102,7 @@ impl<B: Backend> HuberLoss<B> {
/// - predictions: [...dims]
/// - targets: [...dims]
/// - output: [...dims]
pub fn forward_no_reduction<const D: usize>(
pub fn forward_no_reduction<const D: usize, B: Backend>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
@ -103,7 +116,10 @@ impl<B: Backend> HuberLoss<B> {
///
/// - residuals: [...dims]
/// - output: [...dims]
pub fn forward_residuals<const D: usize>(&self, residuals: Tensor<B, D>) -> Tensor<B, D> {
pub fn forward_residuals<const D: usize, B: Backend>(
&self,
residuals: Tensor<B, D>,
) -> Tensor<B, D> {
let is_large = residuals.clone().abs().greater_elem(self.delta);
// We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
// `sign()` function, in general, suffers from a jump at 0.
@ -138,7 +154,7 @@ mod tests {
let predict = TestTensor::<1>::from_data(predict, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let huber = HuberLossConfig::new(0.5).init(&device);
let huber = HuberLossConfig::new(0.5).init();
let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
@ -166,7 +182,7 @@ mod tests {
let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
let targets = TestAutodiffTensor::from_data(targets, &device);
let loss = HuberLossConfig::new(0.5).init(&device);
let loss = HuberLossConfig::new(0.5).init();
let loss = loss.forward_no_reduction(predict.clone(), targets);
let grads = loss.backward();
@ -175,4 +191,15 @@ mod tests {
let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]);
grads_predict.to_data().assert_approx_eq(&expected, 3);
}
#[test]
fn display() {
let config = HuberLossConfig::new(0.5);
let loss = config.init();
assert_eq!(
alloc::format!("{}", loss),
"HuberLoss {delta: 0.5, lin_bias: 0.125}"
);
}
}
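
Because `HuberLoss` is no longer generic over the backend, callers drop the device argument from `init` and let the backend be inferred at the forward call. A minimal usage sketch, assuming the `ndarray` backend feature for the `NdArray` alias:

use burn::nn::loss::{HuberLossConfig, Reduction};
use burn::tensor::Tensor;

// Backend alias for illustration; assumes the `ndarray` feature is enabled.
type B = burn::backend::NdArray;

fn huber_usage() {
    let device = Default::default();

    // Previously: `HuberLossConfig::new(0.5).init::<B>(&device)`.
    // The loss is now backend-agnostic and `init` takes no device.
    let huber = HuberLossConfig::new(0.5).init();

    let predictions = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);
    let targets = Tensor::<B, 1>::from_floats([0.1, 1.5, 2.0], &device);

    // The backend is inferred from the tensors at the call site.
    let loss = huber.forward(predictions, targets, Reduction::Auto);
    println!("{}", loss);
}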

View File

@ -1,26 +1,24 @@
use crate::nn::loss::reduction::Reduction;
use core::marker::PhantomData;
use crate as burn;
use crate::nn::loss::reduction::Reduction;
use crate::module::Module;
use crate::tensor::{backend::Backend, Tensor};
/// Calculate the mean squared error loss from the input logits and the targets.
#[derive(Clone, Debug)]
pub struct MseLoss<B: Backend> {
backend: PhantomData<B>,
}
#[derive(Module, Clone, Debug)]
pub struct MseLoss;
impl<B: Backend> Default for MseLoss<B> {
impl Default for MseLoss {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> MseLoss<B> {
impl MseLoss {
/// Create the criterion.
pub fn new() -> Self {
Self {
backend: PhantomData,
}
Self
}
/// Compute the criterion on the input tensor.
@ -29,7 +27,7 @@ impl<B: Backend> MseLoss<B> {
///
/// - logits: [batch_size, num_targets]
/// - targets: [batch_size, num_targets]
pub fn forward<const D: usize>(
pub fn forward<const D: usize, B: Backend>(
&self,
logits: Tensor<B, D>,
targets: Tensor<B, D>,
@ -43,7 +41,7 @@ impl<B: Backend> MseLoss<B> {
}
/// Compute the criterion on the input tensor without reducing.
pub fn forward_no_reduction<const D: usize>(
pub fn forward_no_reduction<const D: usize, B: Backend>(
&self,
logits: Tensor<B, D>,
targets: Tensor<B, D>,
@ -85,4 +83,10 @@ mod tests {
let expected = TensorData::from([6.0]);
loss_sum.into_data().assert_eq(&expected, false);
}
#[test]
fn display() {
let loss = MseLoss::new();
assert_eq!(alloc::format!("{}", loss), "MseLoss");
}
}
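
With `MseLoss` reduced to a unit struct that derives `Module`, it can be stored as a field of a larger module without carrying a backend parameter. A hypothetical sketch (the `Regressor` type is illustrative only):

use burn::module::Module;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Linear, LinearConfig};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Hypothetical regression model that embeds the loss directly; the
/// backend-agnostic `MseLoss` now derives `Module`, so it can be a field.
#[derive(Module, Debug)]
pub struct Regressor<B: Backend> {
    linear: Linear<B>,
    loss: MseLoss,
}

impl<B: Backend> Regressor<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            linear: LinearConfig::new(4, 1).init(device),
            loss: MseLoss::new(),
        }
    }

    /// Computes predictions and reduces them to a single loss value.
    pub fn compute_loss(&self, input: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        let predictions = self.linear.forward(input);
        self.loss.forward(predictions, targets, Reduction::Mean)
    }
}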

View File

@ -44,8 +44,10 @@ pub struct BatchNorm<B: Backend, const D: usize> {
pub running_mean: RunningState<Tensor<B, 1>>,
/// The running variance.
pub running_var: RunningState<Tensor<B, 1>>,
momentum: f64,
epsilon: f64,
/// Momentum used to update the running statistics.
pub momentum: f64,
/// A value required for numerical stability.
pub epsilon: f64,
}
impl BatchNormConfig {
@ -424,4 +426,15 @@ mod tests_2d {
device,
)
}
#[test]
fn display() {
let batch_norm =
BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&Default::default());
assert_eq!(
format!("{}", batch_norm),
"BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
);
}
}

View File

@ -4,6 +4,7 @@ use crate::nn::Initializer;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -36,16 +37,37 @@ pub struct GroupNormConfig {
///
/// Should be created using [GroupNormConfig](GroupNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct GroupNorm<B: Backend> {
/// The learnable weight
pub gamma: Option<Param<Tensor<B, 1>>>,
/// The learnable bias
pub beta: Option<Param<Tensor<B, 1>>>,
/// The number of groups to separate the channels into
pub num_groups: usize,
/// The number of channels expected in the input
pub num_channels: usize,
/// A value required for numerical stability
pub epsilon: f64,
/// A boolean value that when set to `true`, this module has learnable affine parameters.
pub affine: bool,
}
pub(crate) num_groups: usize,
pub(crate) num_channels: usize,
pub(crate) epsilon: f64,
pub(crate) affine: bool,
impl<B: Backend> ModuleDisplay for GroupNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("num_groups", &self.num_groups)
.add("num_channels", &self.num_channels)
.add("epsilon", &self.epsilon)
.add("affine", &self.affine)
.optional()
}
}
impl GroupNormConfig {
@ -169,6 +191,7 @@ mod tests {
use super::*;
use crate::tensor::TensorData;
use crate::TestBackend;
use alloc::format;
#[test]
fn group_norm_forward_affine_false() {
@ -292,4 +315,15 @@ mod tests {
]);
output.to_data().assert_approx_eq(&expected, 3);
}
#[test]
fn display() {
let config = GroupNormConfig::new(3, 6);
let group_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{}", group_norm),
"GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
);
}
}

View File

@ -1,6 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::module::{Module, Param};
use crate::nn::norm::group_norm;
use crate::nn::Initializer;
@ -25,15 +26,34 @@ pub struct InstanceNormConfig {
///
/// Should be created using [InstanceNormConfig](InstanceNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct InstanceNorm<B: Backend> {
/// The learnable weight
pub gamma: Option<Param<Tensor<B, 1>>>,
/// The learnable bias
pub beta: Option<Param<Tensor<B, 1>>>,
/// The number of channels expected in the input
pub num_channels: usize,
/// A value required for numerical stability
pub epsilon: f64,
/// A boolean value that when set to `true`, this module has learnable affine parameters.
pub affine: bool,
}
num_channels: usize,
epsilon: f64,
affine: bool,
impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("num_channels", &self.num_channels)
.add("epsilon", &self.epsilon)
.add("affine", &self.affine)
.optional()
}
}
impl InstanceNormConfig {
@ -83,6 +103,7 @@ mod tests {
use super::*;
use crate::tensor::TensorData;
use crate::TestBackend;
use alloc::format;
#[test]
fn instance_norm_forward_affine_false() {
@ -187,4 +208,15 @@ mod tests {
]);
output.to_data().assert_approx_eq(&expected, 3);
}
#[test]
fn display() {
let config = InstanceNormConfig::new(6);
let instance_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{}", instance_norm),
"InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
);
}
}

View File

@ -1,5 +1,6 @@
use crate as burn;
use crate::config::Config;
use crate::module::Content;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
@ -33,9 +34,10 @@ pub struct LayerNormConfig {
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
/// The learnable weight.
gamma: Param<Tensor<B, 1>>,
pub gamma: Param<Tensor<B, 1>>,
/// The learnable bias.
beta: Param<Tensor<B, 1>>,
pub beta: Param<Tensor<B, 1>>,
/// A value required for numerical stability.
epsilon: f64,
}
@ -80,7 +82,7 @@ impl<B: Backend> ModuleDisplay for LayerNorm<B> {
.optional()
}
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_model] = self.gamma.shape().dims;
content
.add("d_model", &d_model)
@ -93,6 +95,7 @@ impl<B: Backend> ModuleDisplay for LayerNorm<B> {
mod tests {
use super::*;
use crate::tensor::TensorData;
use alloc::format;
#[cfg(feature = "std")]
use crate::{TestAutodiffBackend, TestBackend};
@ -157,4 +160,15 @@ mod tests {
let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
tensor_2_grad.to_data().assert_approx_eq(&expected, 3);
}
#[test]
fn display() {
let config = LayerNormConfig::new(6);
let layer_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{}", layer_norm),
"LayerNorm {d_model: 6, epsilon: 0.00001, params: 12}"
);
}
}

View File

@ -3,6 +3,7 @@ use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -48,11 +49,12 @@ impl RmsNormConfig {
///
/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RmsNorm<B: Backend> {
/// The learnable parameter to scale the normalized tensor
pub gamma: Param<Tensor<B, 1>>,
/// A value required for numerical stability
epsilon: f64,
pub epsilon: f64,
}
impl<B: Backend> RmsNorm<B> {
@ -71,11 +73,28 @@ impl<B: Backend> RmsNorm<B> {
}
}
impl<B: Backend> ModuleDisplay for RmsNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_model] = self.gamma.shape().dims;
content
.add("d_model", &d_model)
.add("epsilon", &self.epsilon)
.optional()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::TensorData;
use crate::TestBackend;
use alloc::format;
#[test]
fn rms_norm_forward() {
@ -95,4 +114,15 @@ mod tests {
]);
output.to_data().assert_approx_eq(&expected, 4);
}
#[test]
fn display() {
let config = RmsNormConfig::new(6);
let layer_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{}", layer_norm),
"RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}"
);
}
}

View File

@ -2,6 +2,7 @@ use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -18,8 +19,22 @@ pub struct AdaptiveAvgPool1dConfig {
///
/// Should be created with [AdaptiveAvgPool1dConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AdaptiveAvgPool1d {
output_size: usize,
/// The size of the output.
pub output_size: usize,
}
impl ModuleDisplay for AdaptiveAvgPool1d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("output_size", &self.output_size).optional()
}
}
impl AdaptiveAvgPool1dConfig {
@ -44,3 +59,19 @@ impl AdaptiveAvgPool1d {
adaptive_avg_pool1d(input, self.output_size)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = AdaptiveAvgPool1dConfig::new(3);
let layer = config.init();
assert_eq!(
alloc::format!("{}", layer),
"AdaptiveAvgPool1d {output_size: 3}"
);
}
}

View File

@ -2,6 +2,7 @@ use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -18,8 +19,24 @@ pub struct AdaptiveAvgPool2dConfig {
///
/// Should be created with [AdaptiveAvgPool2dConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AdaptiveAvgPool2d {
output_size: [usize; 2],
/// The size of the output.
pub output_size: [usize; 2],
}
impl ModuleDisplay for AdaptiveAvgPool2d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let output_size = alloc::format!("{:?}", self.output_size);
content.add("output_size", &output_size).optional()
}
}
impl AdaptiveAvgPool2dConfig {
@ -44,3 +61,19 @@ impl AdaptiveAvgPool2d {
adaptive_avg_pool2d(input, self.output_size)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = AdaptiveAvgPool2dConfig::new([3, 3]);
let layer = config.init();
assert_eq!(
alloc::format!("{}", layer),
"AdaptiveAvgPool2d {output_size: [3, 3]}"
);
}
}

View File

@ -1,6 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
@ -40,11 +41,33 @@ pub struct AvgPool1dConfig {
/// [Issue 636](https://github.com/tracel-ai/burn/issues/636)
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool1d {
stride: usize,
kernel_size: usize,
padding: Ignored<PaddingConfig1d>,
count_include_pad: bool,
/// The stride.
pub stride: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig1d>,
/// Whether the padding is counted in the denominator when computing the average.
pub count_include_pad: bool,
}
impl ModuleDisplay for AvgPool1d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &self.kernel_size)
.add("stride", &self.stride)
.add("padding", &self.padding)
.add("count_include_pad", &self.count_include_pad)
.optional()
}
}
impl AvgPool1dConfig {
@ -83,3 +106,19 @@ impl AvgPool1d {
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = AvgPool1dConfig::new(3);
let layer = config.init();
assert_eq!(
alloc::format!("{}", layer),
"AvgPool1d {kernel_size: 3, stride: 1, padding: Valid, count_include_pad: true}"
);
}
}

View File

@ -1,6 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
@ -39,11 +40,33 @@ pub struct AvgPool2dConfig {
/// TODO: Add support for `count_include_pad=False`, see
/// [Issue 636](https://github.com/tracel-ai/burn/issues/636)
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: Ignored<PaddingConfig2d>,
count_include_pad: bool,
/// Stride of the pooling.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Padding configuration.
pub padding: Ignored<PaddingConfig2d>,
/// Whether the padding is counted in the denominator when computing the average.
pub count_include_pad: bool,
}
impl ModuleDisplay for AvgPool2d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
.add("stride", &alloc::format!("{:?}", &self.stride))
.add("padding", &self.padding)
.add("count_include_pad", &self.count_include_pad)
.optional()
}
}
impl AvgPool2dConfig {
@ -82,3 +105,20 @@ impl AvgPool2d {
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = AvgPool2dConfig::new([3, 3]);
let layer = config.init();
assert_eq!(
alloc::format!("{}", layer),
"AvgPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, count_include_pad: true}"
);
}
}

View File

@ -1,6 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
@ -28,11 +29,33 @@ pub struct MaxPool1dConfig {
///
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool1d {
stride: usize,
kernel_size: usize,
padding: Ignored<PaddingConfig1d>,
dilation: usize,
/// The stride.
pub stride: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig1d>,
/// The dilation.
pub dilation: usize,
}
impl ModuleDisplay for MaxPool1d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &self.kernel_size)
.add("stride", &self.stride)
.add("padding", &self.padding)
.add("dilation", &self.dilation)
.optional()
}
}
impl MaxPool1dConfig {
@ -65,3 +88,20 @@ impl MaxPool1d {
max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = MaxPool1dConfig::new(3);
let layer = config.init();
assert_eq!(
alloc::format!("{}", layer),
"MaxPool1d {kernel_size: 3, stride: 1, padding: Valid, dilation: 1}"
);
}
}

View File

@ -1,6 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
@ -28,11 +29,33 @@ pub struct MaxPool2dConfig {
///
/// Should be created with [MaxPool2dConfig](MaxPool2dConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: Ignored<PaddingConfig2d>,
dilation: [usize; 2],
/// The strides.
pub stride: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The padding configuration.
pub padding: Ignored<PaddingConfig2d>,
/// The dilation.
pub dilation: [usize; 2],
}
impl ModuleDisplay for MaxPool2d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
.add("stride", &alloc::format!("{:?}", &self.stride))
.add("padding", &self.padding)
.add("dilation", &alloc::format!("{:?}", &self.dilation))
.optional()
}
}
impl MaxPool2dConfig {
@ -65,3 +88,20 @@ impl MaxPool2d {
max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = MaxPool2dConfig::new([3, 3]);
let layer = config.init();
assert_eq!(
alloc::format!("{}", layer),
"MaxPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, dilation: [1, 1]}"
);
}
}

View File

@ -2,7 +2,8 @@ use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use crate::tensor::TensorData;
@ -40,8 +41,31 @@ pub struct PositionalEncodingConfig {
///
/// Should be created using [PositionalEncodingConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionalEncoding<B: Backend> {
sinusoids: Tensor<B, 3>,
/// The sinusoids used to add positional information to the input embeddings.
pub sinusoids: Tensor<B, 3>,
/// The maximum sequence size to use.
pub max_sequence_size: usize,
/// Max time scale to use.
pub max_timescale: usize,
}
impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [_, _, d_model] = self.sinusoids.shape().dims;
content
.add("d_model", &d_model)
.add("max_sequence_size", &self.max_sequence_size)
.add("max_timescale", &self.max_timescale)
.optional()
}
}
impl PositionalEncodingConfig {
@ -55,7 +79,11 @@ impl PositionalEncodingConfig {
)
.unsqueeze::<3>();
PositionalEncoding { sinusoids }
PositionalEncoding {
sinusoids,
max_sequence_size: self.max_sequence_size,
max_timescale: self.max_timescale,
}
}
}
@ -245,4 +273,15 @@ mod tests {
let input = Tensor::zeros([1, 6_000, d_model], &device);
let _output = pe.forward(input);
}
#[test]
fn display() {
let config = PositionalEncodingConfig::new(4);
let pe = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", pe),
"PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
);
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -9,11 +9,33 @@ use crate::tensor::Tensor;
///
/// Should be created using [PReluConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PRelu<B: Backend> {
/// The learned weights for PReLU. Can be of shape \[1\] or \[num_parameters\], in which case it must
/// be the same as the number of channels in the input tensor.
pub alpha: Param<Tensor<B, 1>>,
/// Alpha value for the PReLU layer.
pub alpha_value: f64,
}
impl<B: Backend> ModuleDisplay for PRelu<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [num_parameters] = self.alpha.shape().dims;
content
.add("num_parameters", &num_parameters)
.add("alpha_value", &self.alpha_value)
.optional()
}
}
/// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init).
#[derive(Config, Debug)]
pub struct PReluConfig {
@ -24,12 +46,14 @@ pub struct PReluConfig {
#[config(default = "0.25")]
pub alpha: f64,
}
impl PReluConfig {
/// Initialize a new [Parametric Relu](PRelu) Layer
pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {
PRelu {
// alpha is a tensor of length num_parameters
alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),
alpha_value: self.alpha,
}
}
}
@ -47,3 +71,19 @@ impl<B: Backend> PRelu<B> {
crate::tensor::activation::prelu(input, self.alpha.val())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
#[test]
fn display() {
let layer = PReluConfig::new().init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", layer),
"PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}"
);
}
}

View File

@ -8,7 +8,7 @@ use crate::tensor::Tensor;
/// See also [relu](burn::tensor::activation::relu)
///
#[derive(Module, Clone, Debug, Default)]
pub struct Relu {}
pub struct Relu;
impl Relu {
/// Create the module.
@ -25,3 +25,15 @@ impl Relu {
crate::tensor::activation::relu(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = Relu::new();
assert_eq!(alloc::format!("{}", layer), "Relu");
}
}

View File

@ -2,6 +2,7 @@ use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::rnn::gate_controller;
use crate::nn::Initializer;
use crate::tensor::activation;
@ -30,11 +31,35 @@ pub struct GruConfig {
///
/// Should be created with [GruConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Gru<B: Backend> {
update_gate: GateController<B>,
reset_gate: GateController<B>,
new_gate: GateController<B>,
d_hidden: usize,
/// The update gate controller.
pub update_gate: GateController<B>,
/// The reset gate controller.
pub reset_gate: GateController<B>,
/// The new gate controller.
pub new_gate: GateController<B>,
/// The size of the hidden state.
pub d_hidden: usize,
}
impl<B: Backend> ModuleDisplay for Gru<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, _] = self.update_gate.input_transform.weight.shape().dims;
let bias = self.update_gate.input_transform.bias.is_some();
content
.add("d_input", &d_input)
.add("d_hidden", &self.d_hidden)
.add("bias", &bias)
.optional()
}
}
impl GruConfig {
@ -274,4 +299,16 @@ mod tests {
assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
}
#[test]
fn display() {
let config = GruConfig::new(2, 8, true);
let layer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", layer),
"Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}"
);
}
}

View File

@ -2,6 +2,7 @@ use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::rnn::gate_controller::GateController;
use crate::nn::Initializer;
use crate::tensor::activation;
@ -43,6 +44,7 @@ pub struct LstmConfig {
///
/// Should be created with [LstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Lstm<B: Backend> {
/// The input gate regulates which information to update and store in the cell state at each time step.
pub input_gate: GateController<B>,
@ -52,7 +54,27 @@ pub struct Lstm<B: Backend> {
pub output_gate: GateController<B>,
/// The cell gate is used to compute the cell state that stores and carries information through time.
pub cell_gate: GateController<B>,
d_hidden: usize,
/// The size of the hidden state.
pub d_hidden: usize,
}
impl<B: Backend> ModuleDisplay for Lstm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, _] = self.input_gate.input_transform.weight.shape().dims;
let bias = self.input_gate.input_transform.bias.is_some();
content
.add("d_input", &d_input)
.add("d_hidden", &self.d_hidden)
.add("bias", &bias)
.optional()
}
}
impl LstmConfig {
@ -195,12 +217,33 @@ pub struct BiLstmConfig {
///
/// Should be created with [BiLstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiLstm<B: Backend> {
/// LSTM for the forward direction.
pub forward: Lstm<B>,
/// LSTM for the reverse direction.
pub reverse: Lstm<B>,
d_hidden: usize,
/// The size of the hidden state.
pub d_hidden: usize,
}
impl<B: Backend> ModuleDisplay for BiLstm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, _] = self.forward.input_gate.input_transform.weight.shape().dims;
let bias = self.forward.input_gate.input_transform.bias.is_some();
content
.add("d_input", &d_input)
.add("d_hidden", &self.d_hidden)
.add("bias", &bias)
.optional()
}
}
impl BiLstmConfig {
@ -693,4 +736,28 @@ mod tests {
.to_data()
.assert_approx_eq(&expected_cn_without_init_state, 3);
}
#[test]
fn display_lstm() {
let config = LstmConfig::new(2, 3, true);
let layer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", layer),
"Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
);
}
#[test]
fn display_bilstm() {
let config = BiLstmConfig::new(2, 3, true);
let layer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", layer),
"BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
);
}
}

View File

@ -1,6 +1,6 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Int;
use crate::tensor::Tensor;
@ -74,7 +74,11 @@ impl RotaryEncodingConfig {
.repeat(2, 2)
.reshape([self.max_sequence_length, self.d_model, 2]);
RotaryEncoding { freq_complex }
RotaryEncoding {
freq_complex,
max_sequence_length: self.max_sequence_length,
theta: self.theta,
}
}
}
@ -87,9 +91,31 @@ impl RotaryEncodingConfig {
///
/// Should be created using [RotaryEncodingConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RotaryEncoding<B: Backend> {
/// Frequency Tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
freq_complex: Tensor<B, 3>,
pub freq_complex: Tensor<B, 3>,
/// Maximum sequence length of input
pub max_sequence_length: usize,
/// Scaling factor for frequency computation.
pub theta: f32,
}
impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [_, _, d_model] = self.freq_complex.shape().dims;
content
.add("d_model", &d_model)
.add("max_sequence_length", &self.max_sequence_length)
.add("theta", &self.theta)
.optional()
}
}
#[allow(clippy::single_range_in_vec_init)]
@ -238,4 +264,15 @@ mod tests {
let input = Tensor::zeros([1, 5, d_model], &device);
let _output = pe.forward(input);
}
#[test]
fn display() {
let config = RotaryEncodingConfig::new(10, 4);
let pe = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", pe),
"RotaryEncoding {d_model: 2, max_sequence_length: 10, theta: 10000}"
);
}
}

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::activation::silu;
use crate::tensor::{backend::Backend, Tensor};
@ -31,6 +31,7 @@ pub struct SwiGluConfig {
///
/// Should be created with [SwiGluConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct SwiGlu<B: Backend> {
/// The inner linear layer for Swish activation function
/// with `d_input` input features and `d_output` output features.
@ -40,6 +41,23 @@ pub struct SwiGlu<B: Backend> {
pub linear_outer: Linear<B>,
}
impl<B: Backend> ModuleDisplay for SwiGlu<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_input, d_output] = self.linear_inner.weight.shape().dims;
content
.add("d_input", &d_input)
.add("d_output", &d_output)
.add("bias", &self.linear_inner.bias.is_some())
.optional()
}
}
impl SwiGluConfig {
/// Initialize a new [SwiGLU](SwiGlu) activation layer.
pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGlu<B> {
@ -112,4 +130,15 @@ mod tests {
.to_data()
.assert_approx_eq(&expected_output.to_data(), 4);
}
#[test]
fn display() {
let config = SwiGluConfig::new(3, 5);
let swiglu = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", swiglu),
"SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}"
);
}
}

View File

@ -7,7 +7,7 @@ use crate::tensor::Tensor;
/// Applies the tanh activation function element-wise
/// See also [tanh](burn::tensor::activation::tanh)
#[derive(Module, Clone, Debug, Default)]
pub struct Tanh {}
pub struct Tanh;
impl Tanh {
/// Create the module.
@ -24,3 +24,15 @@ impl Tanh {
crate::tensor::activation::tanh(input)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let layer = Tanh::new();
assert_eq!(alloc::format!("{}", layer), "Tanh");
}
}

View File

@ -1,15 +1,15 @@
use crate::tensor::Bool;
use alloc::vec::Vec;
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::Bool;
use crate::{
self as burn,
nn::{attention::MhaCache, cache::TensorCache, Initializer},
};
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::{
config::Config,
module::Module,
nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
@ -57,8 +57,51 @@ pub struct TransformerDecoderConfig {
///
/// Should be created using [TransformerDecoderConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
layers: Vec<TransformerDecoderLayer<B>>,
/// Transformer decoder layers.
pub layers: Vec<TransformerDecoderLayer<B>>,
/// The size of the model.
pub d_model: usize,
/// The size of the position-wise feed-forward network.
pub d_ff: usize,
/// The number of attention heads.
pub n_heads: usize,
/// The number of layers.
pub n_layers: usize,
/// The dropout rate. Default: 0.1
pub dropout: f64,
/// Layer norm will be applied first instead of after the other modules.
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
pub quiet_softmax: bool,
}
impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("d_model", &self.d_model)
.add("d_ff", &self.d_ff)
.add("n_heads", &self.n_heads)
.add("n_layers", &self.n_layers)
.add("dropout", &self.dropout)
.add("norm_first", &self.norm_first)
.add("quiet_softmax", &self.quiet_softmax)
.optional()
}
}
impl TransformerDecoderConfig {
@ -68,7 +111,16 @@ impl TransformerDecoderConfig {
.map(|_| TransformerDecoderLayer::new(self, device))
.collect::<Vec<_>>();
TransformerDecoder { layers }
TransformerDecoder {
layers,
d_model: self.d_model,
d_ff: self.d_ff,
n_heads: self.n_heads,
n_layers: self.n_layers,
dropout: self.dropout,
norm_first: self.norm_first,
quiet_softmax: self.quiet_softmax,
}
}
}
@ -473,4 +525,16 @@ mod tests {
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
}
#[test]
fn display() {
let config = TransformerDecoderConfig::new(2, 4, 2, 3);
let transformer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", transformer),
"TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
);
}
}

View File

@ -1,15 +1,14 @@
use crate::tensor::Bool;
use alloc::vec::Vec;
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::{
self as burn,
nn::{attention::MhaCache, cache::TensorCache, Initializer},
};
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::{
config::Config,
module::Module,
nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
@ -57,8 +56,51 @@ pub struct TransformerEncoderConfig {
///
/// Should be created using [TransformerEncoderConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerEncoder<B: Backend> {
layers: Vec<TransformerEncoderLayer<B>>,
/// The transformer encoder layers.
pub layers: Vec<TransformerEncoderLayer<B>>,
/// The size of the model.
pub d_model: usize,
/// The size of the position-wise feed-forward network.
pub d_ff: usize,
/// The number of attention heads.
pub n_heads: usize,
/// The number of layers.
pub n_layers: usize,
/// The dropout rate. Default: 0.1
pub dropout: f64,
/// Layer norm will be applied first instead of after the other modules.
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
pub quiet_softmax: bool,
}
impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("d_model", &self.d_model)
.add("d_ff", &self.d_ff)
.add("n_heads", &self.n_heads)
.add("n_layers", &self.n_layers)
.add("dropout", &self.dropout)
.add("norm_first", &self.norm_first)
.add("quiet_softmax", &self.quiet_softmax)
.optional()
}
}
/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
@ -98,7 +140,16 @@ impl TransformerEncoderConfig {
.map(|_| TransformerEncoderLayer::new(self, device))
.collect::<Vec<_>>();
TransformerEncoder { layers }
TransformerEncoder {
layers,
d_model: self.d_model,
d_ff: self.d_ff,
n_heads: self.n_heads,
n_layers: self.n_layers,
dropout: self.dropout,
norm_first: self.norm_first,
quiet_softmax: self.quiet_softmax,
}
}
}
@ -392,4 +443,16 @@ mod tests {
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
}
#[test]
fn display() {
let config = TransformerEncoderConfig::new(2, 4, 2, 3);
let transformer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", transformer),
"TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
);
}
}
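
From user code, printing a configured module now produces the one-line summary exercised by the test above. A small sketch, again assuming the `ndarray` backend feature:

use burn::nn::transformer::TransformerEncoderConfig;

// Backend alias for illustration; assumes the `ndarray` feature is enabled.
type B = burn::backend::NdArray;

fn inspect_encoder() {
    let device = Default::default();
    let encoder = TransformerEncoderConfig::new(2, 4, 2, 3).init::<B>(&device);

    // Prints the one-line summary checked by the test above:
    // TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3,
    //   dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}
    println!("{}", encoder);
}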

View File

@ -1,9 +1,9 @@
use crate as burn;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::Initializer;
use crate::{
config::Config,
module::Module,
nn::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig},
tensor::{backend::Backend, Tensor},
};
@ -36,11 +36,34 @@ pub struct PositionWiseFeedForwardConfig {
///
/// Should be created using [PositionWiseFeedForwardConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: Linear<B>,
linear_outer: Linear<B>,
dropout: Dropout,
gelu: Gelu,
/// Linear layer with `d_model` input features and `d_ff` output features.
pub linear_inner: Linear<B>,
/// Linear layer with `d_ff` input features and `d_model` output features.
pub linear_outer: Linear<B>,
/// Dropout layer.
pub dropout: Dropout,
/// GELU activation function.
pub gelu: Gelu,
}
impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [d_model, dff] = self.linear_inner.weight.shape().dims;
content
.add("d_model", &d_model)
.add("d_ff", &dff)
.add("prob", &self.dropout.prob)
.optional()
}
}
impl PositionWiseFeedForwardConfig {
@ -74,3 +97,20 @@ impl<B: Backend> PositionWiseFeedForward<B> {
self.linear_outer.forward(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
#[test]
fn display() {
let config = PositionWiseFeedForwardConfig::new(2, 4);
let pwff = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", pwff),
"PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}"
);
}
}

View File

@ -1,7 +1,8 @@
use crate as burn;
use crate::config::Config;
use crate::module::{Ignored, Module};
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn_tensor::backend::Backend;
use burn_tensor::module::unfold4d;
use burn_tensor::ops::UnfoldOptions;
@ -27,15 +28,43 @@ pub struct Unfold4dConfig {
///
/// Should be created with [Unfold4dConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Unfold4d {
config: Ignored<Unfold4dConfig>,
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The stride of the convolution.
pub stride: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// The padding configuration.
pub padding: [usize; 2],
}
impl ModuleDisplay for Unfold4d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
.add("stride", &alloc::format!("{:?}", &self.stride))
.add("dilation", &alloc::format!("{:?}", &self.dilation))
.add("padding", &alloc::format!("{:?}", &self.padding))
.optional()
}
}
impl Unfold4dConfig {
/// Initializes a new [Unfold4d] module.
pub fn init(&self) -> Unfold4d {
Unfold4d {
config: Ignored(self.clone()),
kernel_size: self.kernel_size,
stride: self.stride,
dilation: self.dilation,
padding: self.padding,
}
}
}
@ -52,12 +81,24 @@ impl Unfold4d {
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
unfold4d(
input,
self.config.kernel_size,
UnfoldOptions::new(
self.config.stride,
self.config.padding,
self.config.dilation,
),
self.kernel_size,
UnfoldOptions::new(self.stride, self.padding, self.dilation),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display() {
let config = Unfold4dConfig::new([3, 3]);
let unfold = config.init();
assert_eq!(
alloc::format!("{}", unfold),
"Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}"
);
}
}
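
Since `Unfold4d` now stores its hyperparameters as public fields rather than an `Ignored<Unfold4dConfig>`, they can be read back directly after initialization. A short sketch (the `burn::nn::Unfold4dConfig` path is assumed from the surrounding code):

use burn::nn::Unfold4dConfig;

fn unfold_fields() {
    let unfold = Unfold4dConfig::new([3, 3]).init();

    // The hyperparameters are now plain public fields instead of an
    // `Ignored<Unfold4dConfig>` wrapper, so they can be read directly.
    assert_eq!(unfold.kernel_size, [3, 3]);
    assert_eq!(unfold.stride, [1, 1]);
    assert_eq!(unfold.dilation, [1, 1]);
    assert_eq!(unfold.padding, [0, 0]);
}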

View File

@ -16,7 +16,7 @@ impl<B: Backend> Net<B> {
pub fn init(device: &B::Device) -> Self {
let fc1 = LinearConfig::new(2, 3).init(device);
let fc2 = LinearConfig::new(3, 4).init(device);
let relu = Relu::default();
let relu = Relu;
Self { fc1, fc2, relu }
}

View File

@ -101,6 +101,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConvTranspose2dNode {
groups: ConstantRecord::new(),
padding: [ConstantRecord::new(); 2],
padding_out: [ConstantRecord::new(); 2],
channels: [ConstantRecord::new(); 2],
};
let item = Record::into_item::<PS>(record);

View File

@ -1,7 +1,7 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type};
use burn::{
module::{Param, ParamId},
module::{ConstantRecord, Param, ParamId},
nn::{PReluConfig, PReluRecord},
record::{PrecisionSettings, Record},
tensor::{Tensor, TensorData},
@ -70,6 +70,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode {
ParamId::new(),
Tensor::from_data(self.alpha.clone().convert::<PS::FloatElem>(), &device),
),
alpha_value: ConstantRecord,
};
let item = Record::into_item::<PS>(record);

View File

@ -122,7 +122,7 @@ impl<LC: LearnerComponents> Learner<LC> {
<LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
LC::EventProcessor: EventProcessor<ItemTrain = OutputTrain, ItemValid = OutputValid>,
{
log::info!("Fitting {}", self.model.to_string());
log::info!("Fitting the model:\n {}", self.model.to_string());
// The reference model is always on the first device provided.
if let Some(device) = self.devices.first() {
self.model = self.model.fork(device);