mirror of https://github.com/tracel-ai/burn.git
Feat/gru (#393)
This commit is contained in:
parent
8c9802c363
commit
f8cd38c071
|
@ -10,15 +10,15 @@ mod embedding;
|
|||
mod gelu;
|
||||
mod initializer;
|
||||
mod linear;
|
||||
mod lstm;
|
||||
mod norm;
|
||||
mod relu;
|
||||
mod rnn;
|
||||
|
||||
pub use dropout::*;
|
||||
pub use embedding::*;
|
||||
pub use gelu::*;
|
||||
pub use initializer::*;
|
||||
pub use linear::*;
|
||||
pub use lstm::*;
|
||||
pub use norm::*;
|
||||
pub use relu::*;
|
||||
pub use rnn::*;
|
||||
|
|
|
@ -0,0 +1,255 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::nn::rnn::gate_controller;
|
||||
use crate::nn::Initializer;
|
||||
use crate::nn::LinearConfig;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use burn_tensor::activation;
|
||||
|
||||
use super::gate_controller::GateController;
|
||||
|
||||
/// Configuration to create a [gru](Gru) module.
#[derive(Config)]
pub struct GruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Gru transformation.
    pub bias: bool,
    /// Gru initializer
    /// TODO: Make default Xavier initialization. https://github.com/burn-rs/burn/issues/371
    #[config(default = "Initializer::Uniform(0.0, 1.0)")]
    pub initializer: Initializer,
    /// The batch size.
    pub batch_size: usize,
}
|
||||
|
||||
/// The Gru module. This implementation is for a unidirectional, stateless, Gru.
#[derive(Module, Debug)]
pub struct Gru<B: Backend> {
    // Produces z(t): how much of the previous hidden state is kept.
    update_gate: GateController<B>,
    // Produces r(t): how much of the previous hidden state feeds the candidate.
    reset_gate: GateController<B>,
    // Produces the candidate state g(t).
    new_gate: GateController<B>,
    // Batch dimension expected in the input; used to size the zero initial state.
    batch_size: usize,
    // Size of the hidden state.
    d_hidden: usize,
}
|
||||
|
||||
impl GruConfig {
|
||||
/// Initialize a new [gru](Gru) module.
|
||||
pub fn init<B: Backend>(&self) -> Gru<B> {
|
||||
let d_output = self.d_hidden;
|
||||
|
||||
let update_gate = gate_controller::GateController::new(
|
||||
self.d_input,
|
||||
d_output,
|
||||
self.bias,
|
||||
self.initializer.clone(),
|
||||
);
|
||||
let reset_gate = gate_controller::GateController::new(
|
||||
self.d_input,
|
||||
d_output,
|
||||
self.bias,
|
||||
self.initializer.clone(),
|
||||
);
|
||||
let new_gate = gate_controller::GateController::new(
|
||||
self.d_input,
|
||||
d_output,
|
||||
self.bias,
|
||||
self.initializer.clone(),
|
||||
);
|
||||
|
||||
Gru {
|
||||
update_gate,
|
||||
reset_gate,
|
||||
new_gate,
|
||||
batch_size: self.batch_size,
|
||||
d_hidden: self.d_hidden,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize a new [gru](Gru) module.
|
||||
pub fn init_with<B: Backend>(self, record: GruRecord<B>) -> Gru<B> {
|
||||
let linear_config = LinearConfig {
|
||||
d_input: self.d_input,
|
||||
d_output: self.d_hidden,
|
||||
bias: self.bias,
|
||||
initializer: self.initializer.clone(),
|
||||
};
|
||||
|
||||
Gru {
|
||||
update_gate: gate_controller::GateController::new_with(
|
||||
&linear_config,
|
||||
record.update_gate,
|
||||
),
|
||||
reset_gate: gate_controller::GateController::new_with(
|
||||
&linear_config,
|
||||
record.reset_gate,
|
||||
),
|
||||
new_gate: gate_controller::GateController::new_with(&linear_config, record.new_gate),
|
||||
batch_size: self.batch_size,
|
||||
d_hidden: self.d_hidden,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Gru<B> {
|
||||
/// Applies the forward pass on the input tensor. This GRU implementation
|
||||
/// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size].
|
||||
///
|
||||
/// Parameters:
|
||||
/// batched_input: The input tensor of shape [batch_size, sequence_length, input_size].
|
||||
/// state: An optional tensor representing an initial cell state with the same dimensions
|
||||
/// as batched_input. If none is provided, one will be generated.
|
||||
///
|
||||
/// Returns:
|
||||
/// The resulting state tensor, with shape [batch_size, sequence_length, hidden_size].
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
batched_input: Tensor<B, 3>,
|
||||
state: Option<Tensor<B, 3>>,
|
||||
) -> Tensor<B, 3> {
|
||||
let seq_length = batched_input.shape().dims[1];
|
||||
|
||||
let mut hidden_state = match state {
|
||||
Some(state) => state,
|
||||
None => Tensor::zeros([self.batch_size, seq_length, self.d_hidden]),
|
||||
};
|
||||
|
||||
for t in 0..seq_length {
|
||||
let indices = Tensor::arange(t..t + 1);
|
||||
let input_t = batched_input
|
||||
.clone()
|
||||
.index_select(1, indices.clone())
|
||||
.squeeze(1);
|
||||
let hidden_t = hidden_state
|
||||
.clone()
|
||||
.index_select(1, indices.clone())
|
||||
.squeeze(1);
|
||||
|
||||
// u(pdate)g(ate) tensors
|
||||
let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
|
||||
let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t)
|
||||
|
||||
// r(eset)g(ate) tensors
|
||||
let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate);
|
||||
let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t)
|
||||
let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate
|
||||
|
||||
// n(ew)g(ate) tensor
|
||||
let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate);
|
||||
let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t)
|
||||
|
||||
// calculate linear interpolation between previous hidden state and candidate state:
|
||||
// g(t) * (1 - z(t)) + z(t) * hidden_t
|
||||
let state_vector = candidate_state
|
||||
.clone()
|
||||
.mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
|
||||
+ update_values.clone().mul(hidden_t);
|
||||
|
||||
hidden_state = hidden_state.index_assign(
|
||||
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
state_vector.clone().unsqueeze(),
|
||||
);
|
||||
}
|
||||
|
||||
hidden_state
|
||||
}
|
||||
|
||||
/// Helper function for performing weighted matrix product for a gate and adds
|
||||
/// bias, if any.
|
||||
///
|
||||
/// Mathematically, performs `Wx*X + Wh*H + b`, where:
|
||||
/// Wx = weight matrix for the connection to input vector X
|
||||
/// Wh = weight matrix for the connection to hidden state H
|
||||
/// X = input vector
|
||||
/// H = hidden state
|
||||
/// b = bias terms
|
||||
fn gate_product(
|
||||
&self,
|
||||
input: &Tensor<B, 2>,
|
||||
hidden: &Tensor<B, 2>,
|
||||
gate: &GateController<B>,
|
||||
) -> Tensor<B, 2> {
|
||||
let input_product = input.clone().matmul(gate.input_transform.weight.val());
|
||||
let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());
|
||||
|
||||
let input_bias = gate
|
||||
.input_transform
|
||||
.bias
|
||||
.as_ref()
|
||||
.map(|bias_param| bias_param.val());
|
||||
let hidden_bias = gate
|
||||
.hidden_transform
|
||||
.bias
|
||||
.as_ref()
|
||||
.map(|bias_param| bias_param.val());
|
||||
|
||||
match (input_bias, hidden_bias) {
|
||||
(Some(input_bias), Some(hidden_bias)) => {
|
||||
input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
|
||||
}
|
||||
(Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
|
||||
(None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
|
||||
(None, None) => input_product + hidden_product,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{module::Param, nn::LinearRecord, TestBackend};
    use burn_tensor::Data;

    /// Test forward pass with simple input vector.
    ///
    /// Both the input and hidden weight of each gate are set to the same value
    /// (the same record is passed twice to `create_with_weights`), so with a
    /// zero initial hidden state:
    ///
    /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
    /// r_t = sigmoid(0.6*0.1 + 0.6*0) = 0.5150
    /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
    ///
    /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341
    #[test]
    fn tests_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let config = GruConfig::new(1, 1, false, 1);
        let mut gru = config.init::<TestBackend>();

        // Builds a 1x1 gate whose input and hidden transforms share the same
        // hand-picked weight, so the expected output can be computed by hand.
        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
        ) -> GateController<TestBackend> {
            let record = LinearRecord {
                weight: Param::from(Tensor::from_data(Data::from([[weights]]))),
                bias: Some(Param::from(Tensor::from_data(Data::from([biases])))),
            };
            gate_controller::GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record.clone(),
                record,
            )
        }

        gru.update_gate =
            create_gate_controller(0.5, 0.0, 1, 1, false, Initializer::UniformDefault);
        gru.reset_gate = create_gate_controller(0.6, 0.0, 1, 1, false, Initializer::UniformDefault);
        gru.new_gate = create_gate_controller(0.7, 0.0, 1, 1, false, Initializer::UniformDefault);

        // Single batch, single timestep, single feature.
        let input = Tensor::<TestBackend, 3>::from_data(Data::from([[[0.1]]]));

        let state = gru.forward(input, None);

        // Drop the batch dimension to compare against the 2-D expected data.
        let output = state.index_select(0, Tensor::arange(0..1)).squeeze(0);

        output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3);
    }
}
|
|
@ -1,14 +1,13 @@
|
|||
use burn_tensor::activation;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::nn::lstm::gate_controller;
|
||||
use crate::nn::rnn::gate_controller;
|
||||
use crate::nn::Initializer;
|
||||
use crate::nn::LinearConfig;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use burn_tensor::activation;
|
||||
|
||||
use super::gate_controller::GateController;
|
||||
|
||||
|
@ -25,7 +24,7 @@ pub struct LstmConfig {
|
|||
/// a better choice. https://github.com/burn-rs/burn/issues/371
|
||||
#[config(default = "Initializer::Uniform(0.0, 1.0)")]
|
||||
pub initializer: Initializer,
|
||||
/// The batch size
|
||||
/// The batch size.
|
||||
pub batch_size: usize,
|
||||
}
|
||||
|
||||
|
@ -112,10 +111,10 @@ impl LstmConfig {
|
|||
impl<B: Backend> Lstm<B> {
|
||||
/// Applies the forward pass on the input tensor. This LSTM implementation
|
||||
/// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`),
|
||||
/// producing 3-dimensional tensors where the dimensions represent [batch_size, seq_length, hidden_size].
|
||||
/// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size].
|
||||
///
|
||||
/// Parameters:
|
||||
/// batched_input: The input tensor of shape [batch_size, seq_length, input_size].
|
||||
/// batched_input: The input tensor of shape [batch_size, sequence_length, input_size].
|
||||
/// state: An optional tuple of tensors representing the initial cell state and hidden state.
|
||||
/// Each state tensor has shape [batch_size, hidden_size].
|
||||
/// If no initial state is provided, these tensors are initialized to zeros.
|
||||
|
@ -123,7 +122,7 @@ impl<B: Backend> Lstm<B> {
|
|||
/// Returns:
|
||||
/// A tuple of tensors, where the first tensor represents the cell states and
|
||||
/// the second tensor represents the hidden states for each sequence element.
|
||||
/// Both output tensors have the shape [batch_size, seq_length, hidden_size].
|
||||
/// Both output tensors have the shape [batch_size, sequence_length, hidden_size].
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
batched_input: Tensor<B, 3>,
|
||||
|
@ -257,12 +256,12 @@ mod tests {
|
|||
.assert_in_range(0.0, 1.0);
|
||||
}
|
||||
|
||||
/// Test forward pass with simple input vector
|
||||
/// Test forward pass with simple input vector.
|
||||
///
|
||||
/// f_t = sigmoid(0.7*0 + 0.8*0) = 0.5
|
||||
/// i_t = sigmoid(0.5*0.1 + 0.6*0) = sigmoid(0.05) = 0.5123725
|
||||
/// o_t = sigmoid(1.1*0.1 + 1.2*0) = sigmoid(0.11) = 0.5274723
|
||||
/// c_t = tanh(0.9*0.1 + 1.0*0) = tanh(0.09) = 0.0892937
|
||||
/// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928
|
||||
/// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725
|
||||
/// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723
|
||||
/// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937
|
||||
|
||||
/// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243
|
||||
/// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648
|
|
@ -1,5 +1,5 @@
|
|||
mod gate_controller;
|
||||
#[allow(clippy::module_inception)]
|
||||
pub mod gru;
|
||||
pub mod lstm;
|
||||
|
||||
pub use gate_controller::*;
|
Loading…
Reference in New Issue