Mathias Insley 2023-06-09 15:56:40 -07:00 committed by GitHub
parent 8c9802c363
commit f8cd38c071
5 changed files with 269 additions and 15 deletions

burn-core/src/nn/mod.rs

@@ -10,15 +10,15 @@ mod embedding;
 mod gelu;
 mod initializer;
 mod linear;
-mod lstm;
 mod norm;
 mod relu;
+mod rnn;
 
 pub use dropout::*;
 pub use embedding::*;
 pub use gelu::*;
 pub use initializer::*;
 pub use linear::*;
-pub use lstm::*;
 pub use norm::*;
 pub use relu::*;
+pub use rnn::*;

burn-core/src/nn/rnn/gru.rs (new file, 255 lines)

@@ -0,0 +1,255 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::nn::rnn::gate_controller;
use crate::nn::Initializer;
use crate::nn::LinearConfig;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::activation;

use super::gate_controller::GateController;
#[derive(Config)]
pub struct GruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// Whether a bias should be applied during the Gru transformation.
    pub bias: bool,
    /// Gru initializer.
    /// TODO: Make default Xavier initialization. https://github.com/burn-rs/burn/issues/371
    #[config(default = "Initializer::Uniform(0.0, 1.0)")]
    pub initializer: Initializer,
    /// The batch size.
    pub batch_size: usize,
}
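
A hedged construction sketch, assuming the usual surface generated by burn's `Config` derive (a `new` constructor over the non-defaulted fields in declaration order, plus `with_*` setters; the test at the bottom of this file calls `GruConfig::new(1, 1, false, 1)` in exactly this shape):

// Arbitrary illustrative sizes: d_input = 16, d_hidden = 32, bias = true, batch_size = 4.
let config = GruConfig::new(16, 32, true, 4)
    // Overriding the defaulted field; `with_initializer` is assumed from the derive.
    .with_initializer(Initializer::Uniform(-0.1, 0.1));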
/// The Gru module. This implementation is for a unidirectional, stateless Gru.
#[derive(Module, Debug)]
pub struct Gru<B: Backend> {
    update_gate: GateController<B>,
    reset_gate: GateController<B>,
    new_gate: GateController<B>,
    batch_size: usize,
    d_hidden: usize,
}
impl GruConfig {
    /// Initialize a new [gru](Gru) module.
    pub fn init<B: Backend>(&self) -> Gru<B> {
        let d_output = self.d_hidden;

        let update_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );
        let reset_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );
        let new_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );

        Gru {
            update_gate,
            reset_gate,
            new_gate,
            batch_size: self.batch_size,
            d_hidden: self.d_hidden,
        }
    }

    /// Initialize a new [gru](Gru) module with a [record](GruRecord).
    pub fn init_with<B: Backend>(self, record: GruRecord<B>) -> Gru<B> {
        let linear_config = LinearConfig {
            d_input: self.d_input,
            d_output: self.d_hidden,
            bias: self.bias,
            initializer: self.initializer.clone(),
        };

        Gru {
            update_gate: gate_controller::GateController::new_with(
                &linear_config,
                record.update_gate,
            ),
            reset_gate: gate_controller::GateController::new_with(
                &linear_config,
                record.reset_gate,
            ),
            new_gate: gate_controller::GateController::new_with(&linear_config, record.new_gate),
            batch_size: self.batch_size,
            d_hidden: self.d_hidden,
        }
    }
}
impl<B: Backend> Gru<B> {
    /// Applies the forward pass on the input tensor. This GRU implementation
    /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size].
    ///
    /// Parameters:
    ///     batched_input: The input tensor of shape [batch_size, sequence_length, input_size].
    ///     state: An optional tensor representing an initial hidden state with shape
    ///            [batch_size, sequence_length, hidden_size]. If none is provided, one is
    ///            initialized to zeros.
    ///
    /// Returns:
    ///     The resulting state tensor, with shape [batch_size, sequence_length, hidden_size].
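    ///
    /// A minimal call sketch (illustrative names only):
    /// `let states = gru.forward(batched_input, None);`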
    pub fn forward(
        &mut self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        let seq_length = batched_input.shape().dims[1];
        let mut hidden_state = match state {
            Some(state) => state,
            None => Tensor::zeros([self.batch_size, seq_length, self.d_hidden]),
        };

        for t in 0..seq_length {
            let indices = Tensor::arange(t..t + 1);
            let input_t = batched_input
                .clone()
                .index_select(1, indices.clone())
                .squeeze(1);
            let hidden_t = hidden_state
                .clone()
                .index_select(1, indices.clone())
                .squeeze(1);

            // u(pdate)g(ate) tensors
            let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
            let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t)

            // r(eset)g(ate) tensors
            let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate);
            let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t)
            let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate

            // n(ew)g(ate) tensor
            let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate);
            let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t)

            // calculate linear interpolation between previous hidden state and candidate state:
            // g(t) * (1 - z(t)) + z(t) * hidden_t
            let state_vector = candidate_state
                .clone()
                .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
                + update_values.clone().mul(hidden_t);

            hidden_state = hidden_state.index_assign(
                [0..self.batch_size, t..(t + 1), 0..self.d_hidden],
                state_vector.clone().unsqueeze(),
            );
        }

        hidden_state
    }
    /// Helper function that performs the weighted matrix product for a gate and
    /// adds the bias, if any.
    ///
    /// Mathematically, performs `Wx*X + Wh*H + b`, where:
    ///     Wx = weight matrix for the connection to input vector X
    ///     Wh = weight matrix for the connection to hidden state H
    ///     X = input vector
    ///     H = hidden state
    ///     b = bias terms
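    ///
    /// As a shape sketch (inferred from the gate construction above, for
    /// illustration): X is [batch_size, d_input], Wx is [d_input, d_hidden],
    /// H is [batch_size, d_hidden], Wh is [d_hidden, d_hidden], and the
    /// result is [batch_size, d_hidden].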
    fn gate_product(
        &self,
        input: &Tensor<B, 2>,
        hidden: &Tensor<B, 2>,
        gate: &GateController<B>,
    ) -> Tensor<B, 2> {
        let input_product = input.clone().matmul(gate.input_transform.weight.val());
        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());

        let input_bias = gate
            .input_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());
        let hidden_bias = gate
            .hidden_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());

        match (input_bias, hidden_bias) {
            (Some(input_bias), Some(hidden_bias)) => {
                input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
            }
            (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
            (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
            (None, None) => input_product + hidden_product,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{module::Param, nn::LinearRecord, TestBackend};
    use burn_tensor::Data;

    /// Test forward pass with simple input vector.
    ///
    /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
    /// r_t = sigmoid(0.6*0.1 + 0.6*0) = 0.5150
    /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
    ///
    /// h_t = z_t * h' + (1 - z_t) * g_t = 0 + 0.4875 * 0.0699 = 0.0341
    #[test]
    fn tests_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let config = GruConfig::new(1, 1, false, 1);
        let mut gru = config.init::<TestBackend>();

        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
        ) -> GateController<TestBackend> {
            let record = LinearRecord {
                weight: Param::from(Tensor::from_data(Data::from([[weights]]))),
                bias: Some(Param::from(Tensor::from_data(Data::from([biases])))),
            };
            gate_controller::GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record.clone(),
                record,
            )
        }

        gru.update_gate =
            create_gate_controller(0.5, 0.0, 1, 1, false, Initializer::UniformDefault);
        gru.reset_gate = create_gate_controller(0.6, 0.0, 1, 1, false, Initializer::UniformDefault);
        gru.new_gate = create_gate_controller(0.7, 0.0, 1, 1, false, Initializer::UniformDefault);

        let input = Tensor::<TestBackend, 3>::from_data(Data::from([[[0.1]]]));

        let state = gru.forward(input, None);

        let output = state.index_select(0, Tensor::arange(0..1)).squeeze(0);
        output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3);
    }
}
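
Taken together, the new module can be exercised end to end roughly as in the sketch below — a minimal example, assuming the `burn::nn::gru` path implied by the `pub use rnn::*;` re-export in the first hunk; all sizes are arbitrary:

use burn::nn::gru::GruConfig; // path assumption, via the nn re-exports
use burn::tensor::{backend::Backend, Tensor};

fn gru_demo<B: Backend>() -> Tensor<B, 3> {
    // d_input = 3, d_hidden = 4, bias = true, batch_size = 2.
    let mut gru = GruConfig::new(3, 4, true, 2).init::<B>();
    // Batch of 2 sequences, 5 steps each, 3 features per step.
    let input = Tensor::<B, 3>::zeros([2, 5, 3]);
    // No initial state: forward() starts the hidden state at zeros.
    gru.forward(input, None) // shape [2, 5, 4]
}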

burn-core/src/nn/rnn/lstm.rs

@@ -1,14 +1,13 @@
-use burn_tensor::activation;
 use crate as burn;
 use crate::config::Config;
 use crate::module::Module;
-use crate::nn::lstm::gate_controller;
+use crate::nn::rnn::gate_controller;
 use crate::nn::Initializer;
 use crate::nn::LinearConfig;
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
+use burn_tensor::activation;
 
 use super::gate_controller::GateController;
@@ -25,7 +24,7 @@ pub struct LstmConfig {
     /// a better choice. https://github.com/burn-rs/burn/issues/371
     #[config(default = "Initializer::Uniform(0.0, 1.0)")]
     pub initializer: Initializer,
-    /// The batch size
+    /// The batch size.
    pub batch_size: usize,
 }
@@ -112,10 +111,10 @@ impl LstmConfig {
 impl<B: Backend> Lstm<B> {
     /// Applies the forward pass on the input tensor. This LSTM implementation
     /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`),
-    /// producing 3-dimensional tensors where the dimensions represent [batch_size, seq_length, hidden_size].
+    /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size].
     ///
     /// Parameters:
-    ///     batched_input: The input tensor of shape [batch_size, seq_length, input_size].
+    ///     batched_input: The input tensor of shape [batch_size, sequence_length, input_size].
     ///     state: An optional tuple of tensors representing the initial cell state and hidden state.
     ///            Each state tensor has shape [batch_size, hidden_size].
     ///            If no initial state is provided, these tensors are initialized to zeros.
@@ -123,7 +122,7 @@ impl<B: Backend> Lstm<B> {
     /// Returns:
     ///     A tuple of tensors, where the first tensor represents the cell states and
     ///     the second tensor represents the hidden states for each sequence element.
-    ///     Both output tensors have the shape [batch_size, seq_length, hidden_size].
+    ///     Both output tensors have the shape [batch_size, sequence_length, hidden_size].
     pub fn forward(
         &mut self,
         batched_input: Tensor<B, 3>,
@@ -257,12 +256,12 @@ mod tests {
             .assert_in_range(0.0, 1.0);
     }
 
-    /// Test forward pass with simple input vector
+    /// Test forward pass with simple input vector.
     ///
-    /// f_t = sigmoid(0.7*0 + 0.8*0) = 0.5
-    /// i_t = sigmoid(0.5*0.1 + 0.6*0) = sigmoid(0.05) = 0.5123725
-    /// o_t = sigmoid(1.1*0.1 + 1.2*0) = sigmoid(0.11) = 0.5274723
-    /// c_t = tanh(0.9*0.1 + 1.0*0) = tanh(0.09) = 0.0892937
+    /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928
+    /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725
+    /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723
+    /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937
     /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243
     /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648

burn-core/src/nn/rnn/mod.rs

@@ -1,5 +1,5 @@
 mod gate_controller;
-#[allow(clippy::module_inception)]
+pub mod gru;
 pub mod lstm;
 
 pub use gate_controller::*;
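
Assuming `burn-core/src/nn/mod.rs` keeps the `pub use rnn::*;` glob from the first hunk, both recurrent modules surface as submodules of `nn`; a hedged path sketch:

// Paths assumed from the re-export wiring above.
use burn::nn::gru::{Gru, GruConfig};
use burn::nn::lstm::{Lstm, LstmConfig};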