From 8a88a868ee4ed86eb669622ed4f70ff5cd007654 Mon Sep 17 00:00:00 2001 From: Mathias Insley <55096933+agelas@users.noreply.github.com> Date: Tue, 6 Jun 2023 11:33:22 -0700 Subject: [PATCH] Feat/lstm (#370) --- burn-core/src/nn/linear.rs | 4 +- burn-core/src/nn/lstm/gate_controller.rs | 87 ++++++ burn-core/src/nn/lstm/lstm.rs | 322 +++++++++++++++++++++++ burn-core/src/nn/lstm/mod.rs | 6 + burn-core/src/nn/mod.rs | 2 + burn-tensor/src/tensor/api/base.rs | 44 ++++ burn-tensor/src/tensor/api/check.rs | 17 ++ burn-tensor/src/tests/mod.rs | 1 + burn-tensor/src/tests/ops/mod.rs | 1 + burn-tensor/src/tests/ops/squeeze.rs | 37 +++ 10 files changed, 519 insertions(+), 2 deletions(-) create mode 100644 burn-core/src/nn/lstm/gate_controller.rs create mode 100644 burn-core/src/nn/lstm/lstm.rs create mode 100644 burn-core/src/nn/lstm/mod.rs create mode 100644 burn-tensor/src/tests/ops/squeeze.rs diff --git a/burn-core/src/nn/linear.rs b/burn-core/src/nn/linear.rs index 97d3843b2..059f7ae86 100644 --- a/burn-core/src/nn/linear.rs +++ b/burn-core/src/nn/linear.rs @@ -37,8 +37,8 @@ pub struct LinearConfig { /// `U(-k, k)`, where `k = sqrt(1 / d_input)` #[derive(Module, Debug)] pub struct Linear { - weight: Param>, - bias: Option>>, + pub(crate) weight: Param>, + pub(crate) bias: Option>>, } impl LinearConfig { diff --git a/burn-core/src/nn/lstm/gate_controller.rs b/burn-core/src/nn/lstm/gate_controller.rs new file mode 100644 index 000000000..c063301ac --- /dev/null +++ b/burn-core/src/nn/lstm/gate_controller.rs @@ -0,0 +1,87 @@ +use crate as burn; + +use crate::module::Module; +use crate::nn::Initializer; +use crate::nn::Linear; +use crate::nn::LinearConfig; +use burn_tensor::backend::Backend; + +/// A GateController represents a gate in an LSTM cell. An +/// LSTM cell generally contains three gates: an input gate, +/// forget gate, and cell gate. +/// +/// An Lstm gate is modeled as two linear transformations. +/// The results of these transformations are used to calculate +/// the gate's output. +#[derive(Module, Debug)] +pub struct GateController { + /// Represents the affine transformation applied to input vector + pub(crate) input_transform: Linear, + /// Represents the affine transformation applied to the hidden state + pub(crate) hidden_transform: Linear, +} + +impl GateController { + /// Initialize a new [gate_controller](GateController) module. + pub fn new(d_input: usize, d_output: usize, bias: bool, initializer: Initializer) -> Self { + Self { + input_transform: LinearConfig { + d_input, + d_output, + bias, + initializer: initializer.clone(), + } + .init(), + hidden_transform: LinearConfig { + d_input: d_output, + d_output, + bias, + initializer, + } + .init(), + } + } + + /// Initialize a new [gate_controller](GateController) module with a [record](GateControllerRecord). + pub fn new_with(linear_config: &LinearConfig, record: GateControllerRecord) -> Self { + let l1 = LinearConfig::init_with(linear_config, record.input_transform); + let l2 = LinearConfig::init_with(linear_config, record.hidden_transform); + + Self { + input_transform: l1, + hidden_transform: l2, + } + } + + /// Used to initialize a gate controller with known weight layers, + /// allowing for predictable behavior. Used only for testing in + /// lstm. 
+ #[cfg(test)] + pub fn create_with_weights( + d_input: usize, + d_output: usize, + bias: bool, + initializer: Initializer, + input_record: crate::nn::LinearRecord, + hidden_record: crate::nn::LinearRecord, + ) -> Self { + let l1 = LinearConfig { + d_input, + d_output, + bias, + initializer: initializer.clone(), + } + .init_with(input_record); + let l2 = LinearConfig { + d_input, + d_output, + bias, + initializer, + } + .init_with(hidden_record); + Self { + input_transform: l1, + hidden_transform: l2, + } + } +} diff --git a/burn-core/src/nn/lstm/lstm.rs b/burn-core/src/nn/lstm/lstm.rs new file mode 100644 index 000000000..899218b2b --- /dev/null +++ b/burn-core/src/nn/lstm/lstm.rs @@ -0,0 +1,322 @@ +use burn_tensor::activation; + +use crate as burn; + +use crate::config::Config; +use crate::module::Module; +use crate::nn::lstm::gate_controller; +use crate::nn::Initializer; +use crate::nn::LinearConfig; +use crate::tensor::backend::Backend; +use crate::tensor::Tensor; + +use super::gate_controller::GateController; + +#[derive(Config)] +pub struct LstmConfig { + /// The size of the input features. + pub d_input: usize, + /// The size of the hidden state. + pub d_hidden: usize, + /// If a bias should be applied during the Lstm transformation. + pub bias: bool, + /// Lstm initializer + /// TODO: Make default Xavier initialization, which should be + /// a better choice. https://github.com/burn-rs/burn/issues/371 + #[config(default = "Initializer::Uniform(0.0, 1.0)")] + pub initializer: Initializer, + /// The batch size + pub batch_size: usize, +} + +/// The Lstm module. This implementation is for a unidirectional, stateless, Lstm. +#[derive(Module, Debug)] +pub struct Lstm { + input_gate: GateController, + forget_gate: GateController, + output_gate: GateController, + cell_gate: GateController, + batch_size: usize, + d_hidden: usize, +} + +impl LstmConfig { + /// Initialize a new [lstm](Lstm) module. + pub fn init(&self) -> Lstm { + let d_output = self.d_hidden; + + let input_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let forget_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let output_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + let cell_gate = gate_controller::GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + ); + + Lstm { + input_gate, + forget_gate, + output_gate, + cell_gate, + batch_size: self.batch_size, + d_hidden: self.d_hidden, + } + } + + /// Initialize a new [lstm](lstm) module with a [record](LstmRecord). + pub fn init_with(&self, record: LstmRecord) -> Lstm { + let linear_config = LinearConfig { + d_input: self.d_input, + d_output: self.d_hidden, + bias: self.bias, + initializer: self.initializer.clone(), + }; + + Lstm { + input_gate: gate_controller::GateController::new_with( + &linear_config, + record.input_gate, + ), + forget_gate: gate_controller::GateController::new_with( + &linear_config, + record.forget_gate, + ), + output_gate: gate_controller::GateController::new_with( + &linear_config, + record.output_gate, + ), + cell_gate: gate_controller::GateController::new_with(&linear_config, record.cell_gate), + batch_size: self.batch_size, + d_hidden: self.d_hidden, + } + } +} + +impl Lstm { + /// Applies the forward pass on the input tensor. 
This LSTM implementation + /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), + /// producing 3-dimensional tensors where the dimensions represent [batch_size, seq_length, hidden_size]. + /// + /// Parameters: + /// batched_input: The input tensor of shape [batch_size, seq_length, input_size]. + /// state: An optional tuple of tensors representing the initial cell state and hidden state. + /// Each state tensor has shape [batch_size, hidden_size]. + /// If no initial state is provided, these tensors are initialized to zeros. + /// + /// Returns: + /// A tuple of tensors, where the first tensor represents the cell states and + /// the second tensor represents the hidden states for each sequence element. + /// Both output tensors have the shape [batch_size, seq_length, hidden_size]. + pub fn forward( + &mut self, + batched_input: Tensor, + state: Option<(Tensor, Tensor)>, + ) -> (Tensor, Tensor) { + let seq_length = batched_input.shape().dims[1]; + let mut batched_cell_state = Tensor::zeros([self.batch_size, seq_length, self.d_hidden]); + let mut batched_hidden_state = Tensor::zeros([self.batch_size, seq_length, self.d_hidden]); + + let (mut cell_state, mut hidden_state) = match state { + Some((cell_state, hidden_state)) => (cell_state, hidden_state), + None => ( + Tensor::zeros([self.batch_size, self.d_hidden]), + Tensor::zeros([self.batch_size, self.d_hidden]), + ), + }; + + for t in 0..seq_length { + let indices = Tensor::arange(t..t + 1); + let input_t = batched_input.clone().index_select(1, indices).squeeze(1); + // f(orget)g(ate) tensors + let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate); + let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state + + // i(nput)g(ate) tensors + let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate); + let add_values = activation::sigmoid(biased_ig_input_sum); + + // o(utput)g(ate) tensors + let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate); + let output_values = activation::sigmoid(biased_og_input_sum); + + // c(ell)g(ate) tensors + let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate); + let candidate_cell_values = biased_cg_input_sum.tanh(); + + cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values; + hidden_state = output_values * cell_state.clone().tanh(); + + // store the state for this timestep + batched_cell_state = batched_cell_state.index_assign( + [0..self.batch_size, t..(t + 1), 0..self.d_hidden], + cell_state.clone().unsqueeze(), + ); + batched_hidden_state = batched_hidden_state.index_assign( + [0..self.batch_size, t..(t + 1), 0..self.d_hidden], + hidden_state.clone().unsqueeze(), + ); + } + + (batched_cell_state, batched_hidden_state) + } + + /// Helper function for performing weighted matrix product for a gate and adds + /// bias, if any. 
+ /// + /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Wx = weight matrix for the connection to input vector X + /// Wh = weight matrix for the connection to hidden state H + /// X = input vector + /// H = hidden state + /// b = bias terms + fn gate_product( + &self, + input: &Tensor, + hidden: &Tensor, + gate: &GateController, + ) -> Tensor { + let input_product = input.clone().matmul(gate.input_transform.weight.val()); + let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); + + let input_bias = gate + .input_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); + let hidden_bias = gate + .hidden_transform + .bias + .as_ref() + .map(|bias_param| bias_param.val()); + + match (input_bias, hidden_bias) { + (Some(input_bias), Some(hidden_bias)) => { + input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() + } + (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, + (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), + (None, None) => input_product + hidden_product, + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::{module::Param, nn::LinearRecord, TestBackend}; + use burn_tensor::Data; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = LstmConfig::new(5, 5, false, 2); + let lstm = config.init::(); + + lstm.input_gate + .input_transform + .weight + .val() + .to_data() + .assert_in_range(0.0, 1.0); + lstm.forget_gate + .input_transform + .weight + .val() + .to_data() + .assert_in_range(0.0, 1.0); + lstm.output_gate + .input_transform + .weight + .val() + .to_data() + .assert_in_range(0.0, 1.0); + lstm.cell_gate + .input_transform + .weight + .val() + .to_data() + .assert_in_range(0.0, 1.0); + } + + /// Test forward pass with simple input vector + /// + /// f_t = sigmoid(0.7*0 + 0.8*0) = 0.5 + /// i_t = sigmoid(0.5*0.1 + 0.6*0) = sigmoid(0.05) = 0.5123725 + /// o_t = sigmoid(1.1*0.1 + 1.2*0) = sigmoid(0.11) = 0.5274723 + /// c_t = tanh(0.9*0.1 + 1.0*0) = tanh(0.09) = 0.0892937 + + /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243 + /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648 + #[test] + fn test_forward_single_input_single_feature() { + TestBackend::seed(0); + let config = LstmConfig::new(1, 1, false, 1); + let mut lstm = config.init::(); + + fn create_gate_controller( + weights: f32, + biases: f32, + d_input: usize, + d_output: usize, + bias: bool, + initializer: Initializer, + ) -> GateController { + let record = LinearRecord { + weight: Param::from(Tensor::from_data(Data::from([[weights]]))), + bias: Some(Param::from(Tensor::from_data(Data::from([biases])))), + }; + gate_controller::GateController::create_with_weights( + d_input, + d_output, + bias, + initializer, + record.clone(), + record, + ) + } + + lstm.input_gate = + create_gate_controller(0.5, 0.0, 1, 1, false, Initializer::UniformDefault); + lstm.forget_gate = + create_gate_controller(0.7, 0.0, 1, 1, false, Initializer::UniformDefault); + lstm.cell_gate = create_gate_controller(0.9, 0.0, 1, 1, false, Initializer::UniformDefault); + lstm.output_gate = + create_gate_controller(1.1, 0.0, 1, 1, false, Initializer::UniformDefault); + + // single timestep with single feature + let input = Tensor::::from_data(Data::from([[[0.1]]])); + + let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None); + let cell_state = cell_state_batch + .index_select(0, 
Tensor::arange(0..1)) + .squeeze(0); + let hidden_state = hidden_state_batch + .index_select(0, Tensor::arange(0..1)) + .squeeze(0); + cell_state + .to_data() + .assert_approx_eq(&Data::from([[0.046]]), 3); + hidden_state + .to_data() + .assert_approx_eq(&Data::from([[0.024]]), 3) + } +} diff --git a/burn-core/src/nn/lstm/mod.rs b/burn-core/src/nn/lstm/mod.rs new file mode 100644 index 000000000..4ad4c67c2 --- /dev/null +++ b/burn-core/src/nn/lstm/mod.rs @@ -0,0 +1,6 @@ +mod gate_controller; +#[allow(clippy::module_inception)] +pub mod lstm; + +pub use gate_controller::*; +pub use lstm::*; diff --git a/burn-core/src/nn/mod.rs b/burn-core/src/nn/mod.rs index 12c867eac..c060c22ee 100644 --- a/burn-core/src/nn/mod.rs +++ b/burn-core/src/nn/mod.rs @@ -10,6 +10,7 @@ mod embedding; mod gelu; mod initializer; mod linear; +mod lstm; mod norm; mod relu; @@ -18,5 +19,6 @@ pub use embedding::*; pub use gelu::*; pub use initializer::*; pub use linear::*; +pub use lstm::*; pub use norm::*; pub use relu::*; diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs index ad614daf7..43bbc4dfa 100644 --- a/burn-tensor/src/tensor/api/base.rs +++ b/burn-tensor/src/tensor/api/base.rs @@ -110,6 +110,50 @@ where Tensor::new(K::reshape::(self.primitive, new_dims.into())) } + /// Squeeze the tensor along the given dimension, removing the specified dimension + /// of size one, and effectively reducing the rank of the tensor by one. + /// + /// # Arguments + /// + /// - `dim`: The dimension to be squeezed. + /// + /// # Type Parameters + /// + /// - 'D2': The resulting number of dimensions in the squeezed tensor. + /// + /// # Returns + /// + /// A new `Tensor` instance with the specified dimenension removed. + /// + /// # Example + /// + /// ```rust + /// + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([2, 1, 4])); + /// + /// // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1 + /// let squeezed_tensor: Tensor:: = tensor.squeeze(1); + /// + /// // Resulting tensor will have dimensions (2, 4) + /// println!("{:?}", squeezed_tensor.shape()); + /// } + /// ``` + pub fn squeeze(self, dim: usize) -> Tensor { + check!(TensorCheck::squeeze::(dim, &self.shape().dims)); + + let current_dims = self.shape().dims; + let mut new_dims: [usize; D2] = [0; D2]; + + new_dims[..dim].copy_from_slice(¤t_dims[..dim]); + new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]); + + Tensor::new(K::reshape::(self.primitive, new_dims.into())) + } + /// Unsqueeze the current tensor. Create new dimensions to fit the given size. 
/// /// # Panics diff --git a/burn-tensor/src/tensor/api/check.rs b/burn-tensor/src/tensor/api/check.rs index 24ec25f85..a19095d7d 100644 --- a/burn-tensor/src/tensor/api/check.rs +++ b/burn-tensor/src/tensor/api/check.rs @@ -133,6 +133,23 @@ impl TensorCheck { check } + pub(crate) fn squeeze(dim: usize, tensor_dims: &[usize]) -> Self { + let mut check = Self::Ok; + // This should actually be to check that the dimension to squeeze + // has a size of 1 + if tensor_dims[dim] != 1 { + check = check.register( + "Squeeze", + TensorError::new(format!( + "Can't squeeze dimension {} because its size is not 1", + dim + )), + ); + } + + check + } + pub(crate) fn unsqueeze() -> Self { let mut check = Self::Ok; if D2 < D1 { diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs index 8399e8c01..12895703c 100644 --- a/burn-tensor/src/tests/mod.rs +++ b/burn-tensor/src/tests/mod.rs @@ -47,6 +47,7 @@ macro_rules! testgen_all { burn_tensor::testgen_reshape!(); burn_tensor::testgen_flatten!(); burn_tensor::testgen_sin!(); + burn_tensor::testgen_squeeze!(); burn_tensor::testgen_tanh!(); burn_tensor::testgen_sub!(); burn_tensor::testgen_transpose!(); diff --git a/burn-tensor/src/tests/ops/mod.rs b/burn-tensor/src/tests/ops/mod.rs index 4293caf7f..df2f6deb2 100644 --- a/burn-tensor/src/tests/ops/mod.rs +++ b/burn-tensor/src/tests/ops/mod.rs @@ -22,6 +22,7 @@ mod repeat; mod reshape; mod sin; mod sqrt; +mod squeeze; mod sub; mod tanh; mod transpose; diff --git a/burn-tensor/src/tests/ops/squeeze.rs b/burn-tensor/src/tests/ops/squeeze.rs new file mode 100644 index 000000000..d8de064bd --- /dev/null +++ b/burn-tensor/src/tests/ops/squeeze.rs @@ -0,0 +1,37 @@ +#[burn_tensor_testgen::testgen(squeeze)] +mod tests { + use super::*; + use burn_tensor::{Data, Shape, Tensor}; + + /// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor. + #[test] + fn should_squeeze() { + let tensor = Tensor::::ones(Shape::new([2, 1, 4])); + let squeezed_tensor: Tensor = tensor.squeeze(1); + let expected_shape = Shape::new([2, 4]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor. + #[test] + fn should_squeeze_first() { + let tensor = Tensor::::ones(Shape::new([1, 3, 4, 5])); + let squeezed_tensor: Tensor = tensor.squeeze(0); + let expected_shape = Shape::new([3, 4, 5]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor. + #[test] + fn should_squeeze_last() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 1])); + let squeezed_tensor: Tensor = tensor.squeeze(3); + let expected_shape = Shape::new([2, 3, 4]); + assert_eq!(squeezed_tensor.shape(), expected_shape); + } + /// Test if the function panics when the squeezed dimension is not of size 1. + #[test] + #[should_panic] + fn should_squeeze_panic() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let squeezed_tensor: Tensor = tensor.squeeze(2); + } +}
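
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new `Lstm` module introduced above could be driven end to end. It assumes the `burn` facade crate re-exports `burn_core::nn` and the tensor API as `burn::nn` and `burn::tensor`, and that some concrete `Backend` implementation is in scope; the hidden size of 16 is arbitrary and only for illustration.

```rust
use burn::nn::{Lstm, LstmConfig};
use burn::tensor::{backend::Backend, Tensor};

/// Runs the unidirectional LSTM from this patch over a batch of sequences and
/// returns the hidden state for every timestep.
fn run_lstm<B: Backend>(batched_input: Tensor<B, 3>) -> Tensor<B, 3> {
    // The input is expected to have shape [batch_size, seq_length, d_input].
    let shape = batched_input.shape();
    let (batch_size, d_input) = (shape.dims[0], shape.dims[2]);
    let d_hidden = 16; // arbitrary hidden size for this sketch

    // `bias = true` enables the bias term in both linear transforms of each gate;
    // the initializer keeps its `Uniform(0.0, 1.0)` default from `LstmConfig`.
    let mut lstm: Lstm<B> = LstmConfig::new(d_input, d_hidden, true, batch_size).init();

    // Passing `None` starts from zero-initialized cell and hidden states.
    let (_cell_states, hidden_states) = lstm.forward(batched_input, None);

    // Both outputs have shape [batch_size, seq_length, d_hidden]. A size-1 batch
    // dimension could be dropped with the `squeeze` op added in this patch,
    // e.g. `hidden_states.squeeze::<2>(0)` when `batch_size == 1`.
    hidden_states
}
```

Since `forward` takes `&mut self` but the module itself is stateless, a stateful variant would slice the last timestep out of the returned cell and hidden tensors and pass that pair back as the `state` argument on the next call.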