mirror of https://github.com/tracel-ai/burn.git
Feat/lstm (#370)
This commit is contained in:
parent bff752b1a8
commit 8a88a868ee
@@ -37,8 +37,8 @@ pub struct LinearConfig {
 /// `U(-k, k)`, where `k = sqrt(1 / d_input)`
 #[derive(Module, Debug)]
 pub struct Linear<B: Backend> {
-    weight: Param<Tensor<B, 2>>,
-    bias: Option<Param<Tensor<B, 1>>>,
+    pub(crate) weight: Param<Tensor<B, 2>>,
+    pub(crate) bias: Option<Param<Tensor<B, 1>>>,
 }

 impl LinearConfig {
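This visibility change lets sibling modules in the crate read a `Linear` layer's parameters directly. A minimal sketch of the kind of crate-internal access this enables (assumed paths; it mirrors the `gate_product` helper added later in this commit):

```rust
use crate::nn::Linear;
use crate::tensor::{backend::Backend, Tensor};

// With pub(crate) fields, other modules can use the raw parameter tensor
// instead of calling the layer's forward method; `val()` returns the current
// value of the parameter.
fn input_product<B: Backend>(linear: &Linear<B>, input: Tensor<B, 2>) -> Tensor<B, 2> {
    input.matmul(linear.weight.val())
}
```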
@@ -0,0 +1,87 @@
use crate as burn;

use crate::module::Module;
use crate::nn::Initializer;
use crate::nn::Linear;
use crate::nn::LinearConfig;
use burn_tensor::backend::Backend;

/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,
/// a forget gate, and an output gate, plus a cell (candidate)
/// update that this implementation also models as a gate.
///
/// An LSTM gate is modeled as two linear transformations.
/// The results of these transformations are used to calculate
/// the gate's output.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
    /// Represents the affine transformation applied to the input vector
    pub(crate) input_transform: Linear<B>,
    /// Represents the affine transformation applied to the hidden state
    pub(crate) hidden_transform: Linear<B>,
}

impl<B: Backend> GateController<B> {
    /// Initialize a new [gate_controller](GateController) module.
    pub fn new(d_input: usize, d_output: usize, bias: bool, initializer: Initializer) -> Self {
        Self {
            input_transform: LinearConfig {
                d_input,
                d_output,
                bias,
                initializer: initializer.clone(),
            }
            .init(),
            hidden_transform: LinearConfig {
                d_input: d_output,
                d_output,
                bias,
                initializer,
            }
            .init(),
        }
    }

    /// Initialize a new [gate_controller](GateController) module with a [record](GateControllerRecord).
    pub fn new_with(linear_config: &LinearConfig, record: GateControllerRecord<B>) -> Self {
        let l1 = LinearConfig::init_with(linear_config, record.input_transform);
        let l2 = LinearConfig::init_with(linear_config, record.hidden_transform);

        Self {
            input_transform: l1,
            hidden_transform: l2,
        }
    }

    /// Used to initialize a gate controller with known weight layers,
    /// allowing for predictable behavior. Used only for testing in
    /// lstm.
    #[cfg(test)]
    pub fn create_with_weights(
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        input_record: crate::nn::LinearRecord<B>,
        hidden_record: crate::nn::LinearRecord<B>,
    ) -> Self {
        let l1 = LinearConfig {
            d_input,
            d_output,
            bias,
            initializer: initializer.clone(),
        }
        .init_with(input_record);
        let l2 = LinearConfig {
            d_input,
            d_output,
            bias,
            initializer,
        }
        .init_with(hidden_record);
        Self {
            input_transform: l1,
            hidden_transform: l2,
        }
    }
}
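To make "two linear transformations" concrete: for the sigmoid-activated gates, the gate output is `sigmoid(x·Wx + h·Wh + b)`. A self-contained scalar sketch (plain Rust, not the crate's API; the helper names are invented for illustration):

```rust
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// One scalar feature: gate(x, h) = sigmoid(w_x * x + w_h * h + b).
fn gate_output(w_x: f32, w_h: f32, b: f32, x: f32, h: f32) -> f32 {
    sigmoid(w_x * x + w_h * h + b)
}

fn main() {
    // Using the forget-gate weights from the forward-pass test below (0.7 for
    // both transforms, no bias), an input of 0.1, and a zero hidden state:
    let f_t = gate_output(0.7, 0.7, 0.0, 0.1, 0.0);
    println!("f_t = {f_t:.4}"); // ~0.5175
}
```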
@@ -0,0 +1,322 @@
use burn_tensor::activation;

use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::nn::lstm::gate_controller;
use crate::nn::Initializer;
use crate::nn::LinearConfig;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

use super::gate_controller::GateController;

#[derive(Config)]
pub struct LstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Lstm transformation.
    pub bias: bool,
    /// Lstm initializer
    /// TODO: Make default Xavier initialization, which should be
    /// a better choice. https://github.com/burn-rs/burn/issues/371
    #[config(default = "Initializer::Uniform(0.0, 1.0)")]
    pub initializer: Initializer,
    /// The batch size.
    pub batch_size: usize,
}

/// The Lstm module. This implementation is for a unidirectional, stateless LSTM.
#[derive(Module, Debug)]
pub struct Lstm<B: Backend> {
    input_gate: GateController<B>,
    forget_gate: GateController<B>,
    output_gate: GateController<B>,
    cell_gate: GateController<B>,
    batch_size: usize,
    d_hidden: usize,
}

impl LstmConfig {
    /// Initialize a new [lstm](Lstm) module.
    pub fn init<B: Backend>(&self) -> Lstm<B> {
        let d_output = self.d_hidden;

        let input_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );
        let forget_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );
        let output_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );
        let cell_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );

        Lstm {
            input_gate,
            forget_gate,
            output_gate,
            cell_gate,
            batch_size: self.batch_size,
            d_hidden: self.d_hidden,
        }
    }

    /// Initialize a new [lstm](Lstm) module with a [record](LstmRecord).
    pub fn init_with<B: Backend>(&self, record: LstmRecord<B>) -> Lstm<B> {
        let linear_config = LinearConfig {
            d_input: self.d_input,
            d_output: self.d_hidden,
            bias: self.bias,
            initializer: self.initializer.clone(),
        };

        Lstm {
            input_gate: gate_controller::GateController::new_with(
                &linear_config,
                record.input_gate,
            ),
            forget_gate: gate_controller::GateController::new_with(
                &linear_config,
                record.forget_gate,
            ),
            output_gate: gate_controller::GateController::new_with(
                &linear_config,
                record.output_gate,
            ),
            cell_gate: gate_controller::GateController::new_with(&linear_config, record.cell_gate),
            batch_size: self.batch_size,
            d_hidden: self.d_hidden,
        }
    }
}
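Aside: a rough usage sketch for the config (not part of the commit). The `Config` derive generates `LstmConfig::new` taking the fields without defaults in declaration order, i.e. `(d_input, d_hidden, bias, batch_size)`, while `initializer` keeps its `Uniform(0.0, 1.0)` default; paths assume the re-exports added elsewhere in this commit:

```rust
use crate::nn::{Lstm, LstmConfig};
use crate::tensor::backend::Backend;

// Build an LSTM with 32 input features, a 64-dimensional hidden state,
// biases enabled, and a fixed batch size of 8.
fn build<B: Backend>() -> Lstm<B> {
    LstmConfig::new(32, 64, true, 8).init::<B>()
}
```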

impl<B: Backend> Lstm<B> {
    /// Applies the forward pass on the input tensor. This LSTM implementation
    /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`),
    /// producing 3-dimensional tensors where the dimensions represent [batch_size, seq_length, hidden_size].
    ///
    /// Parameters:
    ///     batched_input: The input tensor of shape [batch_size, seq_length, input_size].
    ///     state: An optional tuple of tensors representing the initial cell state and hidden state.
    ///            Each state tensor has shape [batch_size, hidden_size].
    ///            If no initial state is provided, these tensors are initialized to zeros.
    ///
    /// Returns:
    ///     A tuple of tensors, where the first tensor represents the cell states and
    ///     the second tensor represents the hidden states for each sequence element.
    ///     Both output tensors have the shape [batch_size, seq_length, hidden_size].
    pub fn forward(
        &mut self,
        batched_input: Tensor<B, 3>,
        state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
    ) -> (Tensor<B, 3>, Tensor<B, 3>) {
        let seq_length = batched_input.shape().dims[1];
        let mut batched_cell_state = Tensor::zeros([self.batch_size, seq_length, self.d_hidden]);
        let mut batched_hidden_state = Tensor::zeros([self.batch_size, seq_length, self.d_hidden]);

        let (mut cell_state, mut hidden_state) = match state {
            Some((cell_state, hidden_state)) => (cell_state, hidden_state),
            None => (
                Tensor::zeros([self.batch_size, self.d_hidden]),
                Tensor::zeros([self.batch_size, self.d_hidden]),
            ),
        };

        for t in 0..seq_length {
            let indices = Tensor::arange(t..t + 1);
            let input_t = batched_input.clone().index_select(1, indices).squeeze(1);
            // f(orget)g(ate) tensors
            let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate);
            let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state

            // i(nput)g(ate) tensors
            let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate);
            let add_values = activation::sigmoid(biased_ig_input_sum);

            // o(utput)g(ate) tensors
            let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate);
            let output_values = activation::sigmoid(biased_og_input_sum);

            // c(ell)g(ate) tensors
            let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate);
            let candidate_cell_values = biased_cg_input_sum.tanh();

            cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
            hidden_state = output_values * cell_state.clone().tanh();

            // store the state for this timestep
            batched_cell_state = batched_cell_state.index_assign(
                [0..self.batch_size, t..(t + 1), 0..self.d_hidden],
                cell_state.clone().unsqueeze(),
            );
            batched_hidden_state = batched_hidden_state.index_assign(
                [0..self.batch_size, t..(t + 1), 0..self.d_hidden],
                hidden_state.clone().unsqueeze(),
            );
        }

        (batched_cell_state, batched_hidden_state)
    }

    /// Helper function that performs the weighted matrix product for a gate and
    /// adds the bias, if any.
    ///
    /// Mathematically, performs `Wx*X + Wh*H + b`, where:
    ///     Wx = weight matrix for the connection to input vector X
    ///     Wh = weight matrix for the connection to hidden state H
    ///     X = input vector
    ///     H = hidden state
    ///     b = bias terms
    fn gate_product(
        &self,
        input: &Tensor<B, 2>,
        hidden: &Tensor<B, 2>,
        gate: &GateController<B>,
    ) -> Tensor<B, 2> {
        let input_product = input.clone().matmul(gate.input_transform.weight.val());
        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());

        let input_bias = gate
            .input_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());
        let hidden_bias = gate
            .hidden_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());

        match (input_bias, hidden_bias) {
            (Some(input_bias), Some(hidden_bias)) => {
                input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
            }
            (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
            (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
            (None, None) => input_product + hidden_product,
        }
    }
}
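Aside: a rough sketch of a full forward pass and the tensor shapes involved (assumed paths via the crate's re-exports; not part of the commit):

```rust
use crate::nn::{Lstm, LstmConfig};
use crate::tensor::{backend::Backend, Tensor};

fn run<B: Backend>() {
    // The config's batch_size (2 here) must match the batch dimension of the input.
    let mut lstm: Lstm<B> = LstmConfig::new(3, 4, true, 2).init();

    // [batch_size, seq_length, d_input] = [2, 5, 3]
    let input = Tensor::<B, 3>::zeros([2, 5, 3]);

    // With `None`, the initial cell and hidden states are zero tensors of
    // shape [batch_size, d_hidden].
    let (cell_states, hidden_states) = lstm.forward(input, None);

    // Both outputs are [batch_size, seq_length, d_hidden] = [2, 5, 4].
    assert_eq!(cell_states.shape().dims, [2, 5, 4]);
    assert_eq!(hidden_states.shape().dims, [2, 5, 4]);
}
```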

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{module::Param, nn::LinearRecord, TestBackend};
    use burn_tensor::Data;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = LstmConfig::new(5, 5, false, 2);
        let lstm = config.init::<TestBackend>();

        lstm.input_gate
            .input_transform
            .weight
            .val()
            .to_data()
            .assert_in_range(0.0, 1.0);
        lstm.forget_gate
            .input_transform
            .weight
            .val()
            .to_data()
            .assert_in_range(0.0, 1.0);
        lstm.output_gate
            .input_transform
            .weight
            .val()
            .to_data()
            .assert_in_range(0.0, 1.0);
        lstm.cell_gate
            .input_transform
            .weight
            .val()
            .to_data()
            .assert_in_range(0.0, 1.0);
    }

    /// Test forward pass with simple input vector.
    ///
    /// Input x = 0.1, initial hidden state h = 0, initial cell state C = 0.
    /// Each gate uses the same weight for its input and hidden transforms.
    ///
    /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5175
    /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5125
    /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5275
    /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0898
    ///
    /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5125 * 0.0898 = 0.0460
    /// h_t = o_t * tanh(C_t) = 0.5275 * 0.0460 = 0.0243
    #[test]
    fn test_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let config = LstmConfig::new(1, 1, false, 1);
        let mut lstm = config.init::<TestBackend>();

        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
        ) -> GateController<TestBackend> {
            let record = LinearRecord {
                weight: Param::from(Tensor::from_data(Data::from([[weights]]))),
                bias: Some(Param::from(Tensor::from_data(Data::from([biases])))),
            };
            gate_controller::GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record.clone(),
                record,
            )
        }

        lstm.input_gate =
            create_gate_controller(0.5, 0.0, 1, 1, false, Initializer::UniformDefault);
        lstm.forget_gate =
            create_gate_controller(0.7, 0.0, 1, 1, false, Initializer::UniformDefault);
        lstm.cell_gate = create_gate_controller(0.9, 0.0, 1, 1, false, Initializer::UniformDefault);
        lstm.output_gate =
            create_gate_controller(1.1, 0.0, 1, 1, false, Initializer::UniformDefault);

        // single timestep with single feature
        let input = Tensor::<TestBackend, 3>::from_data(Data::from([[[0.1]]]));

        let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None);
        let cell_state = cell_state_batch
            .index_select(0, Tensor::arange(0..1))
            .squeeze(0);
        let hidden_state = hidden_state_batch
            .index_select(0, Tensor::arange(0..1))
            .squeeze(0);
        cell_state
            .to_data()
            .assert_approx_eq(&Data::from([[0.046]]), 3);
        hidden_state
            .to_data()
            .assert_approx_eq(&Data::from([[0.024]]), 3)
    }
}
@@ -0,0 +1,6 @@
mod gate_controller;
#[allow(clippy::module_inception)]
pub mod lstm;

pub use gate_controller::*;
pub use lstm::*;
@@ -10,6 +10,7 @@ mod embedding;
 mod gelu;
 mod initializer;
 mod linear;
+mod lstm;
 mod norm;
 mod relu;
@@ -18,5 +19,6 @@ pub use embedding::*;
 pub use gelu::*;
 pub use initializer::*;
 pub use linear::*;
+pub use lstm::*;
 pub use norm::*;
 pub use relu::*;
@@ -110,6 +110,50 @@ where
        Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
    }

    /// Squeeze the tensor along the given dimension, removing the specified dimension
    /// of size one, and effectively reducing the rank of the tensor by one.
    ///
    /// # Arguments
    ///
    /// - `dim`: The dimension to be squeezed.
    ///
    /// # Type Parameters
    ///
    /// - `D2`: The resulting number of dimensions in the squeezed tensor.
    ///
    /// # Returns
    ///
    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 1, 4]));
    ///
    ///     // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1
    ///     let squeezed_tensor: Tensor::<B, 2> = tensor.squeeze(1);
    ///
    ///     // Resulting tensor will have dimensions (2, 4)
    ///     println!("{:?}", squeezed_tensor.shape());
    /// }
    /// ```
    pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));

        let current_dims = self.shape().dims;
        let mut new_dims: [usize; D2] = [0; D2];

        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);

        Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
    }

    /// Unsqueeze the current tensor. Create new dimensions to fit the given size.
    ///
    /// # Panics
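Aside: this new `squeeze` is what lets the LSTM forward pass earlier in this commit drop the singleton sequence dimension after selecting a single timestep. A sketch of that pattern (assumes some backend type `B: Backend`):

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

// [batch, seq_length, features] -> [batch, 1, features] -> [batch, features]
fn timestep<B: Backend>(batched_input: Tensor<B, 3>, t: usize) -> Tensor<B, 2> {
    batched_input
        .index_select(1, Tensor::arange(t..t + 1))
        .squeeze(1)
}
```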
@@ -133,6 +133,23 @@ impl TensorCheck {
        check
    }

    pub(crate) fn squeeze<const D2: usize>(dim: usize, tensor_dims: &[usize]) -> Self {
        let mut check = Self::Ok;
        // Check that the dimension to squeeze has a size of 1.
        if tensor_dims[dim] != 1 {
            check = check.register(
                "Squeeze",
                TensorError::new(format!(
                    "Can't squeeze dimension {} because its size is not 1",
                    dim
                )),
            );
        }

        check
    }

    pub(crate) fn unsqueeze<const D1: usize, const D2: usize>() -> Self {
        let mut check = Self::Ok;
        if D2 < D1 {
@@ -47,6 +47,7 @@ macro_rules! testgen_all {
         burn_tensor::testgen_reshape!();
         burn_tensor::testgen_flatten!();
         burn_tensor::testgen_sin!();
+        burn_tensor::testgen_squeeze!();
         burn_tensor::testgen_tanh!();
         burn_tensor::testgen_sub!();
         burn_tensor::testgen_transpose!();
@@ -22,6 +22,7 @@ mod repeat;
 mod reshape;
 mod sin;
 mod sqrt;
+mod squeeze;
 mod sub;
 mod tanh;
 mod transpose;
@@ -0,0 +1,37 @@
#[burn_tensor_testgen::testgen(squeeze)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Shape, Tensor};

    /// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor.
    #[test]
    fn should_squeeze() {
        let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([2, 1, 4]));
        let squeezed_tensor: Tensor<TestBackend, 2> = tensor.squeeze(1);
        let expected_shape = Shape::new([2, 4]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }
    /// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor.
    #[test]
    fn should_squeeze_first() {
        let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([1, 3, 4, 5]));
        let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(0);
        let expected_shape = Shape::new([3, 4, 5]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }
    /// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor.
    #[test]
    fn should_squeeze_last() {
        let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 1]));
        let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(3);
        let expected_shape = Shape::new([2, 3, 4]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }
    /// Test if the function panics when the squeezed dimension is not of size 1.
    #[test]
    #[should_panic]
    fn should_squeeze_panic() {
        let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
        let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(2);
    }
}