mirror of https://github.com/tracel-ai/burn.git
Print module - implement module display for remaining modules (part2) (#1933)
This commit is contained in:
parent
1ae1c03b2d
commit
98a58c867d
|
@ -1,10 +1,10 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::nn::cache::TensorCache;
|
||||
use crate::nn::Initializer;
|
||||
use crate::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn,
|
||||
tensor::{activation, backend::Backend, Bool, Tensor},
|
||||
};
|
||||
|
@ -53,17 +53,49 @@ pub struct MultiHeadAttentionConfig {
|
|||
///
|
||||
/// Should be created with [MultiHeadAttentionConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct MultiHeadAttention<B: Backend> {
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
output: nn::Linear<B>,
|
||||
dropout: nn::Dropout,
|
||||
activation: nn::Gelu,
|
||||
n_heads: usize,
|
||||
d_k: usize,
|
||||
min_float: f64,
|
||||
quiet_softmax: bool,
|
||||
/// Linear layer to transform the input features into the query space.
|
||||
pub query: nn::Linear<B>,
|
||||
/// Linear layer to transform the input features into the key space.
|
||||
pub key: nn::Linear<B>,
|
||||
/// Linear layer to transform the input features into the value space.
|
||||
pub value: nn::Linear<B>,
|
||||
/// Linear layer to transform the output features back to the original space.
|
||||
pub output: nn::Linear<B>,
|
||||
/// Dropout layer.
|
||||
pub dropout: nn::Dropout,
|
||||
/// Activation function.
|
||||
pub activation: nn::Gelu,
|
||||
/// The size of each linear layer.
|
||||
pub d_model: usize,
|
||||
/// The number of heads.
|
||||
pub n_heads: usize,
|
||||
/// Size of the key and query vectors.
|
||||
pub d_k: usize,
|
||||
/// Minimum value a float can take.
|
||||
pub min_float: f64,
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
pub quiet_softmax: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("d_model", &self.d_model)
|
||||
.add("n_heads", &self.n_heads)
|
||||
.add("d_k", &self.d_k)
|
||||
.add("dropout", &self.dropout.prob)
|
||||
.add("min_float", &self.min_float)
|
||||
.add("quiet_softmax", &self.quiet_softmax)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
/// [Multihead attention](MultiHeadAttention) forward pass input argument.
|
||||
|
@ -99,6 +131,7 @@ impl MultiHeadAttentionConfig {
|
|||
d_k: self.d_model / self.n_heads,
|
||||
min_float: self.min_float,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
d_model: self.d_model,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -478,4 +511,16 @@ mod tests {
|
|||
.into_data()
|
||||
.assert_approx_eq(&output_2.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = MultiHeadAttentionConfig::new(2, 4);
|
||||
let mha = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", mha),
|
||||
"MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \
|
||||
dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,11 +50,16 @@ pub struct Conv1d<B: Backend> {
|
|||
pub weight: Param<Tensor<B, 3>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
stride: usize,
|
||||
kernel_size: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
padding: Ignored<PaddingConfig1d>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: usize,
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: usize,
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// Padding configuration.
|
||||
pub padding: Ignored<PaddingConfig1d>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Conv1d<B> {
|
||||
|
@ -169,4 +174,15 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = Conv1dConfig::new(5, 5, 5);
|
||||
let conv = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", conv),
|
||||
"Conv1d {stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,11 +52,16 @@ pub struct Conv2d<B: Backend> {
|
|||
pub weight: Param<Tensor<B, 4>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
groups: usize,
|
||||
padding: Ignored<PaddingConfig2d>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: [usize; 2],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig2d>,
|
||||
}
|
||||
|
||||
impl Conv2dConfig {
|
||||
|
@ -214,4 +219,15 @@ mod tests {
|
|||
|
||||
assert_eq!(config.initializer, init);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = Conv2dConfig::new([5, 1], [5, 5]);
|
||||
let conv = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", conv),
|
||||
"Conv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
use alloc::format;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Content;
|
||||
use crate::module::DisplaySettings;
|
||||
use crate::module::Module;
|
||||
use crate::module::ModuleDisplay;
|
||||
use crate::module::Param;
|
||||
use crate::nn::conv::checks;
|
||||
use crate::nn::Initializer;
|
||||
|
@ -45,17 +50,46 @@ pub struct ConvTranspose1dConfig {
|
|||
|
||||
/// Applies a 1D transposed convolution over input tensors.
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct ConvTranspose1d<B: Backend> {
|
||||
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size]`
|
||||
pub weight: Param<Tensor<B, 3>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
stride: usize,
|
||||
kernel_size: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
padding: usize,
|
||||
padding_out: usize,
|
||||
/// Stride of the convolution.
|
||||
pub stride: usize,
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: usize,
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: usize,
|
||||
/// The padding output configuration.
|
||||
pub padding_out: usize,
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for ConvTranspose1d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("channels", &format!("{:?}", &self.channels))
|
||||
.add("stride", &self.stride)
|
||||
.add("kernel_size", &self.kernel_size)
|
||||
.add("dilation", &self.dilation)
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &self.padding)
|
||||
.add("padding_out", &self.padding_out)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConvTranspose1dConfig {
|
||||
|
@ -91,6 +125,7 @@ impl ConvTranspose1dConfig {
|
|||
groups: self.groups,
|
||||
padding: self.padding,
|
||||
padding_out: self.padding_out,
|
||||
channels: self.channels,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -150,4 +185,15 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = ConvTranspose1dConfig::new([5, 2], 5);
|
||||
let conv = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", conv),
|
||||
"ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
use alloc::format;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Content;
|
||||
use crate::module::DisplaySettings;
|
||||
use crate::module::Module;
|
||||
use crate::module::ModuleDisplay;
|
||||
use crate::module::Param;
|
||||
use crate::nn::conv::checks;
|
||||
use crate::nn::Initializer;
|
||||
|
@ -45,17 +50,46 @@ pub struct ConvTranspose2dConfig {
|
|||
|
||||
/// Applies a 2D transposed convolution over input tensors.
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct ConvTranspose2d<B: Backend> {
|
||||
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`
|
||||
pub weight: Param<Tensor<B, 4>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
groups: usize,
|
||||
padding: [usize; 2],
|
||||
padding_out: [usize; 2],
|
||||
/// Stride of the convolution.
|
||||
pub stride: [usize; 2],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
pub groups: usize,
|
||||
/// Padding configuration.
|
||||
pub padding: [usize; 2],
|
||||
/// Padding output configuration.
|
||||
pub padding_out: [usize; 2],
|
||||
/// Number of channels.
|
||||
pub channels: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for ConvTranspose2d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("channels", &format!("{:?}", &self.channels))
|
||||
.add("stride", &format!("{:?}", &self.stride))
|
||||
.add("kernel_size", &format!("{:?}", &self.kernel_size))
|
||||
.add("dilation", &format!("{:?}", &self.dilation))
|
||||
.add("groups", &self.groups)
|
||||
.add("padding", &format!("{:?}", &self.padding))
|
||||
.add("padding_out", &format!("{:?}", &self.padding_out))
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConvTranspose2dConfig {
|
||||
|
@ -92,6 +126,7 @@ impl ConvTranspose2dConfig {
|
|||
groups: self.groups,
|
||||
padding: self.padding,
|
||||
padding_out: self.padding_out,
|
||||
channels: self.channels,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -152,4 +187,15 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = ConvTranspose2dConfig::new([5, 2], [5, 5]);
|
||||
let conv = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", conv),
|
||||
"ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::{Distribution, Tensor};
|
||||
|
||||
|
@ -23,7 +23,8 @@ pub struct DropoutConfig {
|
|||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Dropout {
|
||||
prob: f64,
|
||||
/// The probability of randomly zeroes some elements of the input tensor during training.
|
||||
pub prob: f64,
|
||||
}
|
||||
|
||||
impl DropoutConfig {
|
||||
|
@ -62,7 +63,7 @@ impl ModuleDisplay for Dropout {
|
|||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("prob", &self.prob).optional()
|
||||
}
|
||||
}
|
||||
|
@ -99,4 +100,12 @@ mod tests {
|
|||
|
||||
assert_eq!(tensor.to_data(), output.to_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = DropoutConfig::new(0.5);
|
||||
let layer = config.init();
|
||||
|
||||
assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ use super::Initializer;
|
|||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Int;
|
||||
use crate::tensor::Tensor;
|
||||
|
@ -26,12 +27,29 @@ pub struct EmbeddingConfig {
|
|||
///
|
||||
/// Should be created with [EmbeddingConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Embedding<B: Backend> {
|
||||
/// The learnable weights of the module of shape `[n_embedding, d_model]` initialized
|
||||
/// from a normal distribution `N(0, 1)`.
|
||||
pub weight: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Embedding<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [n_embedding, d_model] = self.weight.shape().dims;
|
||||
content
|
||||
.add("n_embedding", &n_embedding)
|
||||
.add("d_model", &d_model)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingConfig {
|
||||
/// Initialize a new [embedding](Embedding) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Embedding<B> {
|
||||
|
@ -100,4 +118,15 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&TensorData::zeros::<f32, _>(embed.weight.shape()), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = EmbeddingConfig::new(100, 10);
|
||||
let embed = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", embed),
|
||||
"Embedding {n_embedding: 100, d_model: 10, params: 1000}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::tensor::Tensor;
|
|||
/// Applies the Gaussian Error Linear Units function element-wise.
|
||||
/// See also [gelu](burn::tensor::activation::gelu)
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct Gelu {}
|
||||
pub struct Gelu;
|
||||
|
||||
impl Gelu {
|
||||
/// Create the module.
|
||||
|
@ -25,3 +25,15 @@ impl Gelu {
|
|||
crate::tensor::activation::gelu(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = Gelu::new();
|
||||
|
||||
assert_eq!(alloc::format!("{}", layer), "Gelu");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate as burn;
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
|
@ -10,6 +11,7 @@ use crate::tensor::activation::leaky_relu;
|
|||
///
|
||||
/// Should be created with [LeakyReluConfig](LeakyReluConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct LeakyRelu {
|
||||
/// The negative slope.
|
||||
pub negative_slope: f64,
|
||||
|
@ -30,6 +32,20 @@ impl LeakyReluConfig {
|
|||
}
|
||||
}
|
||||
|
||||
impl ModuleDisplay for LeakyRelu {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("negative_slope", &self.negative_slope)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl LeakyRelu {
|
||||
/// Forward pass for the Leaky ReLu layer.
|
||||
///
|
||||
|
@ -92,4 +108,13 @@ mod tests {
|
|||
let actual_output = model.forward(input_data);
|
||||
actual_output.to_data().assert_approx_eq(&expected, 4)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = LeakyReluConfig::new().init();
|
||||
assert_eq!(
|
||||
alloc::format!("{}", config),
|
||||
"LeakyRelu {negative_slope: 0.01}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
use crate as burn;
|
||||
use crate::module::DisplaySettings;
|
||||
use crate::module::ModuleDisplay;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::tensor::{backend::Backend, Tensor};
|
||||
|
||||
use super::Initializer;
|
||||
|
@ -93,7 +91,7 @@ impl<B: Backend> ModuleDisplay for Linear<B> {
|
|||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, d_output] = self.weight.shape().dims;
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
|
@ -196,4 +194,15 @@ mod tests {
|
|||
|
||||
assert_eq!(result_1d.into_data(), result_2d.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = LinearConfig::new(3, 5);
|
||||
let linear = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", linear),
|
||||
"Linear {d_input: 3, d_output: 5, bias: true, params: 20}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate as burn;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
|
||||
use crate::tensor::activation::log_sigmoid;
|
||||
use crate::tensor::{backend::Backend, Int, Tensor};
|
||||
|
@ -59,11 +60,30 @@ impl BinaryCrossEntropyLossConfig {
|
|||
///
|
||||
/// Should be created using [BinaryCrossEntropyLossConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct BinaryCrossEntropyLoss<B: Backend> {
|
||||
/// Weights for cross-entropy.
|
||||
pub weights: Option<Tensor<B, 1>>,
|
||||
smoothing: Option<f32>,
|
||||
logits: bool,
|
||||
/// Label smoothing alpha.
|
||||
pub smoothing: Option<f32>,
|
||||
/// Treat the inputs as logits
|
||||
pub logits: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("weights", &self.weights)
|
||||
.add("smoothing", &self.smoothing)
|
||||
.add("logits", &self.logits)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BinaryCrossEntropyLoss<B> {
|
||||
|
@ -368,4 +388,16 @@ mod tests {
|
|||
.init(&device)
|
||||
.forward(logits, targets);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config =
|
||||
BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
|
||||
let loss = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", loss),
|
||||
"BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::tensor::activation::log_softmax;
|
||||
use crate::tensor::{backend::Backend, Bool, Int, Tensor};
|
||||
use crate::{config::Config, module::Module};
|
||||
use alloc::string::ToString;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
|
@ -71,12 +73,39 @@ impl CrossEntropyLossConfig {
|
|||
///
|
||||
/// Should be created using [CrossEntropyLossConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct CrossEntropyLoss<B: Backend> {
|
||||
pad_tokens: Option<Vec<usize>>,
|
||||
/// Pad tokens to ignore in the loss calculation.
|
||||
pub pad_tokens: Option<Vec<usize>>,
|
||||
/// Weights for cross-entropy.
|
||||
pub weights: Option<Tensor<B, 1>>,
|
||||
smoothing: Option<f32>,
|
||||
logits: bool,
|
||||
/// Label smoothing factor.
|
||||
pub smoothing: Option<f32>,
|
||||
/// Use logits as input.
|
||||
pub logits: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
|
||||
alloc::format!("Vec<0..{}>", pad_tokens.len())
|
||||
} else {
|
||||
"None".to_string()
|
||||
};
|
||||
|
||||
content
|
||||
.add("pad_tokens", &pad_tokens)
|
||||
.add("weights", &self.weights)
|
||||
.add("smoothing", &self.smoothing)
|
||||
.add("logits", &self.logits)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> CrossEntropyLoss<B> {
|
||||
|
@ -406,4 +435,17 @@ mod tests {
|
|||
|
||||
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = CrossEntropyLossConfig::new()
|
||||
.with_weights(Some(alloc::vec![3., 7., 0.9]))
|
||||
.with_smoothing(Some(0.5));
|
||||
let loss = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", loss),
|
||||
"CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::{config::Config, module::Module};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use super::Reduction;
|
||||
|
||||
|
@ -16,15 +16,11 @@ pub struct HuberLossConfig {
|
|||
|
||||
impl HuberLossConfig {
|
||||
/// Initialize [Huber loss](HuberLoss).
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> HuberLoss<B> {
|
||||
// device is not needed as of now, but we might want to prepare some data on it
|
||||
// and its consistent with other loss functions
|
||||
let _ = device;
|
||||
pub fn init(&self) -> HuberLoss {
|
||||
self.assertions();
|
||||
HuberLoss {
|
||||
delta: self.delta,
|
||||
lin_bias: self.delta * self.delta * 0.5,
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,14 +48,31 @@ impl HuberLossConfig {
|
|||
/// This loss function is less sensitive to outliers than the mean squared error loss.
|
||||
///
|
||||
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
|
||||
#[derive(Module, Debug)]
|
||||
pub struct HuberLoss<B: Backend> {
|
||||
delta: f32,
|
||||
lin_bias: f32, // delta * delta * 0.5 precomputed
|
||||
_backend: PhantomData<B>,
|
||||
#[derive(Module, Debug, Clone)]
|
||||
#[module(custom_display)]
|
||||
pub struct HuberLoss {
|
||||
/// The bound where the Huber loss function changes from quadratic to linear behaviour.
|
||||
pub delta: f32,
|
||||
/// Precomputed value for the linear bias.
|
||||
pub lin_bias: f32, // delta * delta * 0.5 precomputed
|
||||
}
|
||||
|
||||
impl<B: Backend> HuberLoss<B> {
|
||||
impl ModuleDisplay for HuberLoss {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("delta", &self.delta)
|
||||
.add("lin_bias", &self.lin_bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl HuberLoss {
|
||||
/// Compute the loss element-wise for the predictions and targets, then reduce
|
||||
/// to a single loss value.
|
||||
///
|
||||
|
@ -70,7 +83,7 @@ impl<B: Backend> HuberLoss<B> {
|
|||
/// - predictions: \[...dims\]
|
||||
/// - targets: \[...dims\]
|
||||
/// - output: \[1\]
|
||||
pub fn forward<const D: usize>(
|
||||
pub fn forward<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
|
@ -89,7 +102,7 @@ impl<B: Backend> HuberLoss<B> {
|
|||
/// - predictions: [...dims]
|
||||
/// - targets: [...dims]
|
||||
/// - output: [...dims]
|
||||
pub fn forward_no_reduction<const D: usize>(
|
||||
pub fn forward_no_reduction<const D: usize, B: Backend>(
|
||||
&self,
|
||||
predictions: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
|
@ -103,7 +116,10 @@ impl<B: Backend> HuberLoss<B> {
|
|||
///
|
||||
/// - residuals: [...dims]
|
||||
/// - output: [...dims]
|
||||
pub fn forward_residuals<const D: usize>(&self, residuals: Tensor<B, D>) -> Tensor<B, D> {
|
||||
pub fn forward_residuals<const D: usize, B: Backend>(
|
||||
&self,
|
||||
residuals: Tensor<B, D>,
|
||||
) -> Tensor<B, D> {
|
||||
let is_large = residuals.clone().abs().greater_elem(self.delta);
|
||||
// We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
|
||||
// `sign()` function, in general, suffers from a jump at 0.
|
||||
|
@ -138,7 +154,7 @@ mod tests {
|
|||
let predict = TestTensor::<1>::from_data(predict, &device);
|
||||
let targets = TestTensor::<1>::from_data(targets, &device);
|
||||
|
||||
let huber = HuberLossConfig::new(0.5).init(&device);
|
||||
let huber = HuberLossConfig::new(0.5).init();
|
||||
|
||||
let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
|
||||
let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
|
||||
|
@ -166,7 +182,7 @@ mod tests {
|
|||
let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
|
||||
let targets = TestAutodiffTensor::from_data(targets, &device);
|
||||
|
||||
let loss = HuberLossConfig::new(0.5).init(&device);
|
||||
let loss = HuberLossConfig::new(0.5).init();
|
||||
let loss = loss.forward_no_reduction(predict.clone(), targets);
|
||||
|
||||
let grads = loss.backward();
|
||||
|
@ -175,4 +191,15 @@ mod tests {
|
|||
let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]);
|
||||
grads_predict.to_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = HuberLossConfig::new(0.5);
|
||||
let loss = config.init();
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", loss),
|
||||
"HuberLoss {delta: 0.5, lin_bias: 0.125}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,26 +1,24 @@
|
|||
use crate::nn::loss::reduction::Reduction;
|
||||
use core::marker::PhantomData;
|
||||
use crate as burn;
|
||||
|
||||
use crate::nn::loss::reduction::Reduction;
|
||||
|
||||
use crate::module::Module;
|
||||
use crate::tensor::{backend::Backend, Tensor};
|
||||
|
||||
/// Calculate the mean squared error loss from the input logits and the targets.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MseLoss<B: Backend> {
|
||||
backend: PhantomData<B>,
|
||||
}
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct MseLoss;
|
||||
|
||||
impl<B: Backend> Default for MseLoss<B> {
|
||||
impl Default for MseLoss {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> MseLoss<B> {
|
||||
impl MseLoss {
|
||||
/// Create the criterion.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
backend: PhantomData,
|
||||
}
|
||||
Self
|
||||
}
|
||||
|
||||
/// Compute the criterion on the input tensor.
|
||||
|
@ -29,7 +27,7 @@ impl<B: Backend> MseLoss<B> {
|
|||
///
|
||||
/// - logits: [batch_size, num_targets]
|
||||
/// - targets: [batch_size, num_targets]
|
||||
pub fn forward<const D: usize>(
|
||||
pub fn forward<const D: usize, B: Backend>(
|
||||
&self,
|
||||
logits: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
|
@ -43,7 +41,7 @@ impl<B: Backend> MseLoss<B> {
|
|||
}
|
||||
|
||||
/// Compute the criterion on the input tensor without reducing.
|
||||
pub fn forward_no_reduction<const D: usize>(
|
||||
pub fn forward_no_reduction<const D: usize, B: Backend>(
|
||||
&self,
|
||||
logits: Tensor<B, D>,
|
||||
targets: Tensor<B, D>,
|
||||
|
@ -85,4 +83,10 @@ mod tests {
|
|||
let expected = TensorData::from([6.0]);
|
||||
loss_sum.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let loss = MseLoss::new();
|
||||
assert_eq!(alloc::format!("{}", loss), "MseLoss");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,8 +44,10 @@ pub struct BatchNorm<B: Backend, const D: usize> {
|
|||
pub running_mean: RunningState<Tensor<B, 1>>,
|
||||
/// The running variance.
|
||||
pub running_var: RunningState<Tensor<B, 1>>,
|
||||
momentum: f64,
|
||||
epsilon: f64,
|
||||
/// Momentum used to update the metrics.
|
||||
pub momentum: f64,
|
||||
/// A value required for numerical stability.
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
impl BatchNormConfig {
|
||||
|
@ -424,4 +426,15 @@ mod tests_2d {
|
|||
device,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let batch_norm =
|
||||
BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", batch_norm),
|
||||
"BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ use crate::nn::Initializer;
|
|||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
|
@ -36,16 +37,37 @@ pub struct GroupNormConfig {
|
|||
///
|
||||
/// Should be created using [GroupNormConfig](GroupNormConfig).
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct GroupNorm<B: Backend> {
|
||||
/// The learnable weight
|
||||
pub gamma: Option<Param<Tensor<B, 1>>>,
|
||||
/// The learnable bias
|
||||
pub beta: Option<Param<Tensor<B, 1>>>,
|
||||
/// The number of groups to separate the channels into
|
||||
pub num_groups: usize,
|
||||
/// The number of channels expected in the input
|
||||
pub num_channels: usize,
|
||||
/// A value required for numerical stability
|
||||
pub epsilon: f64,
|
||||
/// A boolean value that when set to `true`, this module has learnable
|
||||
pub affine: bool,
|
||||
}
|
||||
|
||||
pub(crate) num_groups: usize,
|
||||
pub(crate) num_channels: usize,
|
||||
pub(crate) epsilon: f64,
|
||||
pub(crate) affine: bool,
|
||||
impl<B: Backend> ModuleDisplay for GroupNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("num_groups", &self.num_groups)
|
||||
.add("num_channels", &self.num_channels)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.add("affine", &self.affine)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl GroupNormConfig {
|
||||
|
@ -169,6 +191,7 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::tensor::TensorData;
|
||||
use crate::TestBackend;
|
||||
use alloc::format;
|
||||
|
||||
#[test]
|
||||
fn group_norm_forward_affine_false() {
|
||||
|
@ -292,4 +315,15 @@ mod tests {
|
|||
]);
|
||||
output.to_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = GroupNormConfig::new(3, 6);
|
||||
let group_norm = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", group_norm),
|
||||
"GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::module::{Module, Param};
|
||||
use crate::nn::norm::group_norm;
|
||||
use crate::nn::Initializer;
|
||||
|
@ -25,15 +26,34 @@ pub struct InstanceNormConfig {
|
|||
///
|
||||
/// Should be created using [InstanceNormConfig](InstanceNormConfig).
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct InstanceNorm<B: Backend> {
|
||||
/// The learnable weight
|
||||
pub gamma: Option<Param<Tensor<B, 1>>>,
|
||||
/// The learnable bias
|
||||
pub beta: Option<Param<Tensor<B, 1>>>,
|
||||
/// The number of channels expected in the input
|
||||
pub num_channels: usize,
|
||||
/// A value required for numerical stability
|
||||
pub epsilon: f64,
|
||||
/// A boolean value that when set to `true`, this module has learnable
|
||||
pub affine: bool,
|
||||
}
|
||||
|
||||
num_channels: usize,
|
||||
epsilon: f64,
|
||||
affine: bool,
|
||||
impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("num_channels", &self.num_channels)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.add("affine", &self.affine)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl InstanceNormConfig {
|
||||
|
@ -83,6 +103,7 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::tensor::TensorData;
|
||||
use crate::TestBackend;
|
||||
use alloc::format;
|
||||
|
||||
#[test]
|
||||
fn instance_norm_forward_affine_false() {
|
||||
|
@ -187,4 +208,15 @@ mod tests {
|
|||
]);
|
||||
output.to_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = InstanceNormConfig::new(6);
|
||||
let instance_norm = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", instance_norm),
|
||||
"InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate as burn;
|
||||
use crate::config::Config;
|
||||
use crate::module::Content;
|
||||
use crate::module::DisplaySettings;
|
||||
use crate::module::Module;
|
||||
use crate::module::ModuleDisplay;
|
||||
|
@ -33,9 +34,10 @@ pub struct LayerNormConfig {
|
|||
#[module(custom_display)]
|
||||
pub struct LayerNorm<B: Backend> {
|
||||
/// The learnable weight.
|
||||
gamma: Param<Tensor<B, 1>>,
|
||||
pub gamma: Param<Tensor<B, 1>>,
|
||||
/// The learnable bias.
|
||||
beta: Param<Tensor<B, 1>>,
|
||||
pub beta: Param<Tensor<B, 1>>,
|
||||
/// A value required for numerical stability.
|
||||
epsilon: f64,
|
||||
}
|
||||
|
||||
|
@ -80,7 +82,7 @@ impl<B: Backend> ModuleDisplay for LayerNorm<B> {
|
|||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_model] = self.gamma.shape().dims;
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
|
@ -93,6 +95,7 @@ impl<B: Backend> ModuleDisplay for LayerNorm<B> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::TensorData;
|
||||
use alloc::format;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use crate::{TestAutodiffBackend, TestBackend};
|
||||
|
@ -157,4 +160,15 @@ mod tests {
|
|||
let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
|
||||
tensor_2_grad.to_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = LayerNormConfig::new(6);
|
||||
let layer_norm = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", layer_norm),
|
||||
"LayerNorm {d_model: 6, epsilon: 0.00001, params: 12}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ use crate as burn;
|
|||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::nn::Initializer;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
@ -48,11 +49,12 @@ impl RmsNormConfig {
|
|||
///
|
||||
/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct RmsNorm<B: Backend> {
|
||||
/// The learnable parameter to scale the normalized tensor
|
||||
pub gamma: Param<Tensor<B, 1>>,
|
||||
/// A value required for numerical stability
|
||||
epsilon: f64,
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> RmsNorm<B> {
|
||||
|
@ -71,11 +73,28 @@ impl<B: Backend> RmsNorm<B> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for RmsNorm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_model] = self.gamma.shape().dims;
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("epsilon", &self.epsilon)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::TensorData;
|
||||
use crate::TestBackend;
|
||||
use alloc::format;
|
||||
|
||||
#[test]
|
||||
fn rms_norm_forward() {
|
||||
|
@ -95,4 +114,15 @@ mod tests {
|
|||
]);
|
||||
output.to_data().assert_approx_eq(&expected, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = RmsNormConfig::new(6);
|
||||
let layer_norm = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
format!("{}", layer_norm),
|
||||
"RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate as burn;
|
|||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
|
@ -18,8 +19,22 @@ pub struct AdaptiveAvgPool1dConfig {
|
|||
///
|
||||
/// Should be created with [AdaptiveAvgPool1dConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct AdaptiveAvgPool1d {
|
||||
output_size: usize,
|
||||
/// The size of the output.
|
||||
pub output_size: usize,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for AdaptiveAvgPool1d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content.add("output_size", &self.output_size).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool1dConfig {
|
||||
|
@ -44,3 +59,19 @@ impl AdaptiveAvgPool1d {
|
|||
adaptive_avg_pool1d(input, self.output_size)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = AdaptiveAvgPool1dConfig::new(3);
|
||||
let layer = config.init();
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"AdaptiveAvgPool1d {output_size: 3}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate as burn;
|
|||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
|
@ -18,8 +19,24 @@ pub struct AdaptiveAvgPool2dConfig {
|
|||
///
|
||||
/// Should be created with [AdaptiveAvgPool2dConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct AdaptiveAvgPool2d {
|
||||
output_size: [usize; 2],
|
||||
/// The size of the output.
|
||||
pub output_size: [usize; 2],
|
||||
}
|
||||
|
||||
impl ModuleDisplay for AdaptiveAvgPool2d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let output_size = alloc::format!("{:?}", self.output_size);
|
||||
|
||||
content.add("output_size", &output_size).optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl AdaptiveAvgPool2dConfig {
|
||||
|
@ -44,3 +61,19 @@ impl AdaptiveAvgPool2d {
|
|||
adaptive_avg_pool2d(input, self.output_size)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = AdaptiveAvgPool2dConfig::new([3, 3]);
|
||||
let layer = config.init();
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"AdaptiveAvgPool2d {output_size: [3, 3]}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::nn::PaddingConfig1d;
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -40,11 +41,33 @@ pub struct AvgPool1dConfig {
|
|||
/// [Issue 636](https://github.com/tracel-ai/burn/issues/636)
|
||||
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct AvgPool1d {
|
||||
stride: usize,
|
||||
kernel_size: usize,
|
||||
padding: Ignored<PaddingConfig1d>,
|
||||
count_include_pad: bool,
|
||||
/// The stride.
|
||||
pub stride: usize,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig1d>,
|
||||
/// If the padding is counted in the denominator when computing the average.
|
||||
pub count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for AvgPool1d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &self.kernel_size)
|
||||
.add("stride", &self.stride)
|
||||
.add("padding", &self.padding)
|
||||
.add("count_include_pad", &self.count_include_pad)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl AvgPool1dConfig {
|
||||
|
@ -83,3 +106,19 @@ impl AvgPool1d {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = AvgPool1dConfig::new(3);
|
||||
let layer = config.init();
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"AvgPool1d {kernel_size: 3, stride: 1, padding: Valid, count_include_pad: true}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -39,11 +40,33 @@ pub struct AvgPool2dConfig {
|
|||
/// TODO: Add support for `count_include_pad=False`, see
|
||||
/// [Issue 636](https://github.com/tracel-ai/burn/issues/636)
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct AvgPool2d {
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
padding: Ignored<PaddingConfig2d>,
|
||||
count_include_pad: bool,
|
||||
/// Stride of the pooling.
|
||||
pub stride: [usize; 2],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// Padding configuration.
|
||||
pub padding: Ignored<PaddingConfig2d>,
|
||||
/// If the padding is counted in the denominator when computing the average.
|
||||
pub count_include_pad: bool,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for AvgPool2d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
|
||||
.add("stride", &alloc::format!("{:?}", &self.stride))
|
||||
.add("padding", &self.padding)
|
||||
.add("count_include_pad", &self.count_include_pad)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl AvgPool2dConfig {
|
||||
|
@ -82,3 +105,20 @@ impl AvgPool2d {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = AvgPool2dConfig::new([3, 3]);
|
||||
|
||||
let layer = config.init();
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"AvgPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, count_include_pad: true}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::nn::PaddingConfig1d;
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -28,11 +29,33 @@ pub struct MaxPool1dConfig {
|
|||
///
|
||||
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct MaxPool1d {
|
||||
stride: usize,
|
||||
kernel_size: usize,
|
||||
padding: Ignored<PaddingConfig1d>,
|
||||
dilation: usize,
|
||||
/// The stride.
|
||||
pub stride: usize,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig1d>,
|
||||
/// The dilation.
|
||||
pub dilation: usize,
|
||||
}
|
||||
|
||||
impl ModuleDisplay for MaxPool1d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &self.kernel_size)
|
||||
.add("stride", &self.stride)
|
||||
.add("padding", &self.padding)
|
||||
.add("dilation", &self.dilation)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl MaxPool1dConfig {
|
||||
|
@ -65,3 +88,20 @@ impl MaxPool1d {
|
|||
max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = MaxPool1dConfig::new(3);
|
||||
|
||||
let layer = config.init();
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"MaxPool1d {kernel_size: 3, stride: 1, padding: Valid, dilation: 1}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
|
@ -28,11 +29,33 @@ pub struct MaxPool2dConfig {
|
|||
///
|
||||
/// Should be created with [MaxPool2dConfig](MaxPool2dConfig).
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct MaxPool2d {
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
padding: Ignored<PaddingConfig2d>,
|
||||
dilation: [usize; 2],
|
||||
/// The strides.
|
||||
pub stride: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig2d>,
|
||||
/// The dilation.
|
||||
pub dilation: [usize; 2],
|
||||
}
|
||||
|
||||
impl ModuleDisplay for MaxPool2d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
|
||||
.add("stride", &alloc::format!("{:?}", &self.stride))
|
||||
.add("padding", &self.padding)
|
||||
.add("dilation", &alloc::format!("{:?}", &self.dilation))
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl MaxPool2dConfig {
|
||||
|
@ -65,3 +88,20 @@ impl MaxPool2d {
|
|||
max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = MaxPool2dConfig::new([3, 3]);
|
||||
|
||||
let layer = config.init();
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"MaxPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, dilation: [1, 1]}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,8 @@ use alloc::vec::Vec;
|
|||
|
||||
use crate as burn;
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::tensor::TensorData;
|
||||
|
@ -40,8 +41,31 @@ pub struct PositionalEncodingConfig {
|
|||
///
|
||||
/// Should be created using [PositionalEncodingConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct PositionalEncoding<B: Backend> {
|
||||
sinusoids: Tensor<B, 3>,
|
||||
/// The sinusoids used to add positional information to the input embeddings.
|
||||
pub sinusoids: Tensor<B, 3>,
|
||||
/// The maximum sequence size to use.
|
||||
pub max_sequence_size: usize,
|
||||
/// Max time scale to use.
|
||||
pub max_timescale: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [_, _, d_model] = self.sinusoids.shape().dims;
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("max_sequence_size", &self.max_sequence_size)
|
||||
.add("max_timescale", &self.max_timescale)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl PositionalEncodingConfig {
|
||||
|
@ -55,7 +79,11 @@ impl PositionalEncodingConfig {
|
|||
)
|
||||
.unsqueeze::<3>();
|
||||
|
||||
PositionalEncoding { sinusoids }
|
||||
PositionalEncoding {
|
||||
sinusoids,
|
||||
max_sequence_size: self.max_sequence_size,
|
||||
max_timescale: self.max_timescale,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -245,4 +273,15 @@ mod tests {
|
|||
let input = Tensor::zeros([1, 6_000, d_model], &device);
|
||||
let _output = pe.forward(input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = PositionalEncodingConfig::new(4);
|
||||
let pe = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", pe),
|
||||
"PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::nn::Initializer;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
|
@ -9,11 +9,33 @@ use crate::tensor::Tensor;
|
|||
///
|
||||
/// Should be created using [PReluConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct PRelu<B: Backend> {
|
||||
/// the weights learnt for PReLu. can be of shape \[1\] or \[num_parameters\] in which case it must
|
||||
/// be the same as number of channels in the input tensor
|
||||
pub alpha: Param<Tensor<B, 1>>,
|
||||
|
||||
/// Alpha value for the PRelu layer
|
||||
pub alpha_value: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PRelu<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [num_parameters] = self.alpha.shape().dims;
|
||||
|
||||
content
|
||||
.add("num_parameters", &num_parameters)
|
||||
.add("alpha_value", &self.alpha_value)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PReluConfig {
|
||||
|
@ -24,12 +46,14 @@ pub struct PReluConfig {
|
|||
#[config(default = "0.25")]
|
||||
pub alpha: f64,
|
||||
}
|
||||
|
||||
impl PReluConfig {
|
||||
/// Initialize a new [Parametric Relu](PRelu) Layer
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {
|
||||
PRelu {
|
||||
// alpha is a tensor of length num_parameters
|
||||
alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),
|
||||
alpha_value: self.alpha,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -47,3 +71,19 @@ impl<B: Backend> PRelu<B> {
|
|||
crate::tensor::activation::prelu(input, self.alpha.val())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = PReluConfig::new().init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ use crate::tensor::Tensor;
|
|||
/// See also [relu](burn::tensor::activation::relu)
|
||||
///
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct Relu {}
|
||||
pub struct Relu;
|
||||
|
||||
impl Relu {
|
||||
/// Create the module.
|
||||
|
@ -25,3 +25,15 @@ impl Relu {
|
|||
crate::tensor::activation::relu(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = Relu::new();
|
||||
|
||||
assert_eq!(alloc::format!("{}", layer), "Relu");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate as burn;
|
|||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::nn::rnn::gate_controller;
|
||||
use crate::nn::Initializer;
|
||||
use crate::tensor::activation;
|
||||
|
@ -30,11 +31,35 @@ pub struct GruConfig {
|
|||
///
|
||||
/// Should be created with [GruConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Gru<B: Backend> {
|
||||
update_gate: GateController<B>,
|
||||
reset_gate: GateController<B>,
|
||||
new_gate: GateController<B>,
|
||||
d_hidden: usize,
|
||||
/// The update gate controller.
|
||||
pub update_gate: GateController<B>,
|
||||
/// The reset gate controller.
|
||||
pub reset_gate: GateController<B>,
|
||||
/// The new gate controller.
|
||||
pub new_gate: GateController<B>,
|
||||
/// The size of the hidden state.
|
||||
pub d_hidden: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Gru<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, _] = self.update_gate.input_transform.weight.shape().dims;
|
||||
let bias = self.update_gate.input_transform.bias.is_some();
|
||||
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_hidden", &self.d_hidden)
|
||||
.add("bias", &bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl GruConfig {
|
||||
|
@ -274,4 +299,16 @@ mod tests {
|
|||
|
||||
assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = GruConfig::new(2, 8, true);
|
||||
|
||||
let layer = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate as burn;
|
|||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Content, DisplaySettings, ModuleDisplay};
|
||||
use crate::nn::rnn::gate_controller::GateController;
|
||||
use crate::nn::Initializer;
|
||||
use crate::tensor::activation;
|
||||
|
@ -43,6 +44,7 @@ pub struct LstmConfig {
|
|||
///
|
||||
/// Should be created with [LstmConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Lstm<B: Backend> {
|
||||
/// The input gate regulates which information to update and store in the cell state at each time step.
|
||||
pub input_gate: GateController<B>,
|
||||
|
@ -52,7 +54,27 @@ pub struct Lstm<B: Backend> {
|
|||
pub output_gate: GateController<B>,
|
||||
/// The cell gate is used to compute the cell state that stores and carries information through time.
|
||||
pub cell_gate: GateController<B>,
|
||||
d_hidden: usize,
|
||||
/// The hidden state of the LSTM.
|
||||
pub d_hidden: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for Lstm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, _] = self.input_gate.input_transform.weight.shape().dims;
|
||||
let bias = self.input_gate.input_transform.bias.is_some();
|
||||
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_hidden", &self.d_hidden)
|
||||
.add("bias", &bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl LstmConfig {
|
||||
|
@ -195,12 +217,33 @@ pub struct BiLstmConfig {
|
|||
///
|
||||
/// Should be created with [BiLstmConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct BiLstm<B: Backend> {
|
||||
/// LSTM for the forward direction.
|
||||
pub forward: Lstm<B>,
|
||||
/// LSTM for the reverse direction.
|
||||
pub reverse: Lstm<B>,
|
||||
d_hidden: usize,
|
||||
/// The size of the hidden state.
|
||||
pub d_hidden: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for BiLstm<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, _] = self.forward.input_gate.input_transform.weight.shape().dims;
|
||||
let bias = self.forward.input_gate.input_transform.bias.is_some();
|
||||
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_hidden", &self.d_hidden)
|
||||
.add("bias", &bias)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl BiLstmConfig {
|
||||
|
@ -693,4 +736,28 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&expected_cn_without_init_state, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_lstm() {
|
||||
let config = LstmConfig::new(2, 3, true);
|
||||
|
||||
let layer = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_bilstm() {
|
||||
let config = BiLstmConfig::new(2, 3, true);
|
||||
|
||||
let layer = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", layer),
|
||||
"BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate as burn;
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Int;
|
||||
use crate::tensor::Tensor;
|
||||
|
@ -74,7 +74,11 @@ impl RotaryEncodingConfig {
|
|||
.repeat(2, 2)
|
||||
.reshape([self.max_sequence_length, self.d_model, 2]);
|
||||
|
||||
RotaryEncoding { freq_complex }
|
||||
RotaryEncoding {
|
||||
freq_complex,
|
||||
max_sequence_length: self.max_sequence_length,
|
||||
theta: self.theta,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,9 +91,31 @@ impl RotaryEncodingConfig {
|
|||
///
|
||||
/// Should be created using [RotaryEncodingConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct RotaryEncoding<B: Backend> {
|
||||
/// Frequency Tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
|
||||
freq_complex: Tensor<B, 3>,
|
||||
pub freq_complex: Tensor<B, 3>,
|
||||
/// Maximum sequence length of input
|
||||
pub max_sequence_length: usize,
|
||||
/// Scaling factor for frequency computation.
|
||||
pub theta: f32,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [_, _, d_model] = self.freq_complex.shape().dims;
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("max_sequence_length", &self.max_sequence_length)
|
||||
.add("theta", &self.theta)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::single_range_in_vec_init)]
|
||||
|
@ -238,4 +264,15 @@ mod tests {
|
|||
let input = Tensor::zeros([1, 5, d_model], &device);
|
||||
let _output = pe.forward(input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = RotaryEncodingConfig::new(10, 4);
|
||||
let pe = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", pe),
|
||||
"RotaryEncoding {d_model: 2, max_sequence_length: 10, theta: 10000}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::tensor::activation::silu;
|
||||
use crate::tensor::{backend::Backend, Tensor};
|
||||
|
||||
|
@ -31,6 +31,7 @@ pub struct SwiGluConfig {
|
|||
///
|
||||
/// Should be created with [SwiGluConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct SwiGlu<B: Backend> {
|
||||
/// The inner linear layer for Swish activation function
|
||||
/// with `d_input` input features and `d_output` output features.
|
||||
|
@ -40,6 +41,23 @@ pub struct SwiGlu<B: Backend> {
|
|||
pub linear_outer: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for SwiGlu<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_input, d_output] = self.linear_inner.weight.shape().dims;
|
||||
content
|
||||
.add("d_input", &d_input)
|
||||
.add("d_output", &d_output)
|
||||
.add("bias", &self.linear_inner.bias.is_some())
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl SwiGluConfig {
|
||||
/// Initialize a new [SwiGLU](SwiGlu) activation layer.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGlu<B> {
|
||||
|
@ -112,4 +130,15 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&expected_output.to_data(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = SwiGluConfig::new(3, 5);
|
||||
let swiglu = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", swiglu),
|
||||
"SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::tensor::Tensor;
|
|||
/// Applies the tanh activation function element-wise
|
||||
/// See also [tanh](burn::tensor::activation::tanh)
|
||||
#[derive(Module, Clone, Debug, Default)]
|
||||
pub struct Tanh {}
|
||||
pub struct Tanh;
|
||||
|
||||
impl Tanh {
|
||||
/// Create the module.
|
||||
|
@ -24,3 +24,15 @@ impl Tanh {
|
|||
crate::tensor::activation::tanh(input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let layer = Tanh::new();
|
||||
|
||||
assert_eq!(alloc::format!("{}", layer), "Tanh");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
use crate::tensor::Bool;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
|
||||
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::tensor::Bool;
|
||||
use crate::{
|
||||
self as burn,
|
||||
nn::{attention::MhaCache, cache::TensorCache, Initializer},
|
||||
};
|
||||
|
||||
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
|
||||
use crate::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{
|
||||
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
|
||||
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
|
||||
|
@ -57,8 +57,51 @@ pub struct TransformerDecoderConfig {
|
|||
///
|
||||
/// Should be created using [TransformerDecoderConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct TransformerDecoder<B: Backend> {
|
||||
layers: Vec<TransformerDecoderLayer<B>>,
|
||||
/// Transformer decoder layers.
|
||||
pub layers: Vec<TransformerDecoderLayer<B>>,
|
||||
|
||||
/// The size of the model.
|
||||
pub d_model: usize,
|
||||
|
||||
/// The size of the position-wise feed-forward network.
|
||||
pub d_ff: usize,
|
||||
|
||||
/// The number of attention heads.
|
||||
pub n_heads: usize,
|
||||
|
||||
/// The number of layers.
|
||||
pub n_layers: usize,
|
||||
|
||||
/// The dropout rate. Default: 0.1
|
||||
pub dropout: f64,
|
||||
|
||||
/// Layer norm will be applied first instead of after the other modules.
|
||||
pub norm_first: bool,
|
||||
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
pub quiet_softmax: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("d_model", &self.d_model)
|
||||
.add("d_ff", &self.d_ff)
|
||||
.add("n_heads", &self.n_heads)
|
||||
.add("n_layers", &self.n_layers)
|
||||
.add("dropout", &self.dropout)
|
||||
.add("norm_first", &self.norm_first)
|
||||
.add("quiet_softmax", &self.quiet_softmax)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerDecoderConfig {
|
||||
|
@ -68,7 +111,16 @@ impl TransformerDecoderConfig {
|
|||
.map(|_| TransformerDecoderLayer::new(self, device))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
TransformerDecoder { layers }
|
||||
TransformerDecoder {
|
||||
layers,
|
||||
d_model: self.d_model,
|
||||
d_ff: self.d_ff,
|
||||
n_heads: self.n_heads,
|
||||
n_layers: self.n_layers,
|
||||
dropout: self.dropout,
|
||||
norm_first: self.norm_first,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -473,4 +525,16 @@ mod tests {
|
|||
.into_data()
|
||||
.assert_approx_eq(&output_2.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = TransformerDecoderConfig::new(2, 4, 2, 3);
|
||||
let transformer = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", transformer),
|
||||
"TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
|
||||
dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
use crate::tensor::Bool;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::{
|
||||
self as burn,
|
||||
nn::{attention::MhaCache, cache::TensorCache, Initializer},
|
||||
};
|
||||
|
||||
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
|
||||
use crate::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{
|
||||
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
|
||||
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
|
||||
|
@ -57,8 +56,51 @@ pub struct TransformerEncoderConfig {
|
|||
///
|
||||
/// Should be created using [TransformerEncoderConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct TransformerEncoder<B: Backend> {
|
||||
layers: Vec<TransformerEncoderLayer<B>>,
|
||||
/// The transformer encoder layers.
|
||||
pub layers: Vec<TransformerEncoderLayer<B>>,
|
||||
|
||||
/// The size of the model.
|
||||
pub d_model: usize,
|
||||
|
||||
/// The size of the position-wise feed-forward network.
|
||||
pub d_ff: usize,
|
||||
|
||||
/// The number of attention heads.
|
||||
pub n_heads: usize,
|
||||
|
||||
/// The number of layers.
|
||||
pub n_layers: usize,
|
||||
|
||||
/// The dropout rate. Default: 0.1
|
||||
pub dropout: f64,
|
||||
|
||||
/// Layer norm will be applied first instead of after the other modules.
|
||||
pub norm_first: bool,
|
||||
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
pub quiet_softmax: bool,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("d_model", &self.d_model)
|
||||
.add("d_ff", &self.d_ff)
|
||||
.add("n_heads", &self.n_heads)
|
||||
.add("n_layers", &self.n_layers)
|
||||
.add("dropout", &self.dropout)
|
||||
.add("norm_first", &self.norm_first)
|
||||
.add("quiet_softmax", &self.quiet_softmax)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
|
||||
|
@ -98,7 +140,16 @@ impl TransformerEncoderConfig {
|
|||
.map(|_| TransformerEncoderLayer::new(self, device))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
TransformerEncoder { layers }
|
||||
TransformerEncoder {
|
||||
layers,
|
||||
d_model: self.d_model,
|
||||
d_ff: self.d_ff,
|
||||
n_heads: self.n_heads,
|
||||
n_layers: self.n_layers,
|
||||
dropout: self.dropout,
|
||||
norm_first: self.norm_first,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -392,4 +443,16 @@ mod tests {
|
|||
.into_data()
|
||||
.assert_approx_eq(&output_2.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = TransformerEncoderConfig::new(2, 4, 2, 3);
|
||||
let transformer = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", transformer),
|
||||
"TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
|
||||
n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
use crate::nn::Initializer;
|
||||
use crate::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
@ -36,11 +36,34 @@ pub struct PositionWiseFeedForwardConfig {
|
|||
///
|
||||
/// Should be created using [PositionWiseFeedForwardConfig]
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct PositionWiseFeedForward<B: Backend> {
|
||||
linear_inner: Linear<B>,
|
||||
linear_outer: Linear<B>,
|
||||
dropout: Dropout,
|
||||
gelu: Gelu,
|
||||
/// Linear layer with `d_model` input features and `d_ff` output features.
|
||||
pub linear_inner: Linear<B>,
|
||||
/// Linear layer with `d_ff` input features and `d_model` output features.
|
||||
pub linear_outer: Linear<B>,
|
||||
/// Dropout layer.
|
||||
pub dropout: Dropout,
|
||||
/// GELU activation function.
|
||||
pub gelu: Gelu,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
let [d_model, dff] = self.linear_inner.weight.shape().dims;
|
||||
|
||||
content
|
||||
.add("d_model", &d_model)
|
||||
.add("d_ff", &dff)
|
||||
.add("prob", &self.dropout.prob)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl PositionWiseFeedForwardConfig {
|
||||
|
@ -74,3 +97,20 @@ impl<B: Backend> PositionWiseFeedForward<B> {
|
|||
self.linear_outer.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = PositionWiseFeedForwardConfig::new(2, 4);
|
||||
let pwff = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", pwff),
|
||||
"PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Ignored, Module};
|
||||
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
|
||||
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::module::unfold4d;
|
||||
use burn_tensor::ops::UnfoldOptions;
|
||||
|
@ -27,15 +28,43 @@ pub struct Unfold4dConfig {
|
|||
///
|
||||
/// Should be created with [Unfold4dConfig].
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Unfold4d {
|
||||
config: Ignored<Unfold4dConfig>,
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The stride of the convolution.
|
||||
pub stride: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 2],
|
||||
/// The padding configuration.
|
||||
pub padding: [usize; 2],
|
||||
}
|
||||
|
||||
impl ModuleDisplay for Unfold4d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
|
||||
.add("stride", &alloc::format!("{:?}", &self.stride))
|
||||
.add("dilation", &alloc::format!("{:?}", &self.dilation))
|
||||
.add("padding", &alloc::format!("{:?}", &self.padding))
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl Unfold4dConfig {
|
||||
/// Initializes a new [Unfold4d] module.
|
||||
pub fn init(&self) -> Unfold4d {
|
||||
Unfold4d {
|
||||
config: Ignored(self.clone()),
|
||||
kernel_size: self.kernel_size,
|
||||
stride: self.stride,
|
||||
dilation: self.dilation,
|
||||
padding: self.padding,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -52,12 +81,24 @@ impl Unfold4d {
|
|||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
|
||||
unfold4d(
|
||||
input,
|
||||
self.config.kernel_size,
|
||||
UnfoldOptions::new(
|
||||
self.config.stride,
|
||||
self.config.padding,
|
||||
self.config.dilation,
|
||||
),
|
||||
self.kernel_size,
|
||||
UnfoldOptions::new(self.stride, self.padding, self.dilation),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = Unfold4dConfig::new([3, 3]);
|
||||
let unfold = config.init();
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", unfold),
|
||||
"Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ impl<B: Backend> Net<B> {
|
|||
pub fn init(device: &B::Device) -> Self {
|
||||
let fc1 = LinearConfig::new(2, 3).init(device);
|
||||
let fc2 = LinearConfig::new(3, 4).init(device);
|
||||
let relu = Relu::default();
|
||||
let relu = Relu;
|
||||
|
||||
Self { fc1, fc2, relu }
|
||||
}
|
||||
|
|
|
@ -101,6 +101,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConvTranspose2dNode {
|
|||
groups: ConstantRecord::new(),
|
||||
padding: [ConstantRecord::new(); 2],
|
||||
padding_out: [ConstantRecord::new(); 2],
|
||||
channels: [ConstantRecord::new(); 2],
|
||||
};
|
||||
|
||||
let item = Record::into_item::<PS>(record);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::{Node, NodeCodegen, SerializationBackend};
|
||||
use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type};
|
||||
use burn::{
|
||||
module::{Param, ParamId},
|
||||
module::{ConstantRecord, Param, ParamId},
|
||||
nn::{PReluConfig, PReluRecord},
|
||||
record::{PrecisionSettings, Record},
|
||||
tensor::{Tensor, TensorData},
|
||||
|
@ -70,6 +70,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode {
|
|||
ParamId::new(),
|
||||
Tensor::from_data(self.alpha.clone().convert::<PS::FloatElem>(), &device),
|
||||
),
|
||||
alpha_value: ConstantRecord,
|
||||
};
|
||||
|
||||
let item = Record::into_item::<PS>(record);
|
||||
|
|
|
@ -122,7 +122,7 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
<LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
|
||||
LC::EventProcessor: EventProcessor<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||
{
|
||||
log::info!("Fitting {}", self.model.to_string());
|
||||
log::info!("Fitting the model:\n {}", self.model.to_string());
|
||||
// The reference model is always on the first device provided.
|
||||
if let Some(device) = self.devices.first() {
|
||||
self.model = self.model.fork(device);
|
||||
|
|
Loading…
Reference in New Issue