mirror of https://github.com/tracel-ai/burn.git
feature(tensor): Add Tensor::stack (#1004)
This commit is contained in:
parent
d0cb7205fa
commit
2fdf9a3591
|
@ -7,3 +7,4 @@ Cargo.lock
|
|||
|
||||
.idea
|
||||
.vscode
|
||||
.fleet
|
||||
|
|
|
@ -61,6 +61,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
|||
| `Tensor::from_data_device(data, device)` | N/A |
|
||||
| `tensor.into_primitive()` | N/A |
|
||||
| `Tensor::from_primitive(primitive)` | N/A |
|
||||
| `Tensor::stack(tensors, dim)` | torch.stack(tensors, dim)` |
|
||||
|
||||
### Numeric Operations
|
||||
|
||||
|
|
|
@ -457,6 +457,18 @@ where
|
|||
))
|
||||
}
|
||||
|
||||
/// Concatenates all tensors into a new one along a new dimension.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If all tensors don't have the same shape.
|
||||
/// Given dimension is not with range of 0..=D2
|
||||
pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
|
||||
check!(TensorCheck::stack(&tensors, dim));
|
||||
let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
|
||||
Tensor::<B, D2, K>::cat(tensors, dim)
|
||||
}
|
||||
|
||||
/// Iterate over slices of tensors alongside a given dimension.
|
||||
///
|
||||
/// # Panics
|
||||
|
|
|
@ -304,6 +304,51 @@ impl TensorCheck {
|
|||
check
|
||||
}
|
||||
|
||||
pub(crate) fn stack<B: Backend, const D: usize, K: BasicOps<B>>(
|
||||
tensors: &[Tensor<B, D, K>],
|
||||
dim: usize,
|
||||
) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
|
||||
if dim > D {
|
||||
check = check.register(
|
||||
"Stack",
|
||||
TensorError::new(
|
||||
"Can't stack tensors on a dim that exceeds the tensors dimension (inclusive)",
|
||||
)
|
||||
.details(format!(
|
||||
"Trying to concatenate tensors with {D} dimensions on axis {dim}."
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
if tensors.is_empty() {
|
||||
return check.register(
|
||||
"Stack",
|
||||
TensorError::new("Can't stack an empty list of tensors."),
|
||||
);
|
||||
}
|
||||
|
||||
let shape_reference = tensors.get(0).unwrap().shape();
|
||||
|
||||
for tensor in tensors {
|
||||
let shape = tensor.shape();
|
||||
|
||||
if shape_reference != shape {
|
||||
return check.register(
|
||||
"Stack",
|
||||
TensorError::new("Can't stack tensors with different shapes").details(format!(
|
||||
"Provided dimension ({}), tensors shapes: {:?}",
|
||||
dim,
|
||||
tensors.iter().map(Tensor::shape).collect::<Vec<_>>()
|
||||
)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
check
|
||||
}
|
||||
|
||||
pub(crate) fn cat<B: Backend, const D: usize, K: BasicOps<B>>(
|
||||
tensors: &[Tensor<B, D, K>],
|
||||
dim: usize,
|
||||
|
|
|
@ -67,6 +67,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_select!();
|
||||
burn_tensor::testgen_sin!();
|
||||
burn_tensor::testgen_slice!();
|
||||
burn_tensor::testgen_stack!();
|
||||
burn_tensor::testgen_sqrt!();
|
||||
burn_tensor::testgen_abs!();
|
||||
burn_tensor::testgen_squeeze!();
|
||||
|
|
|
@ -37,6 +37,7 @@ mod sin;
|
|||
mod slice;
|
||||
mod sqrt;
|
||||
mod squeeze;
|
||||
mod stack;
|
||||
mod sub;
|
||||
mod tanh;
|
||||
mod transpose;
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
#[burn_tensor_testgen::testgen(stack)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{Bool, Data, Int, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_support_stack_ops_2d_dim0() {
|
||||
let tensor_1: Tensor<TestBackend, 2> = Tensor::from_data([[1.0, 2.0, 3.0]]);
|
||||
let tensor_2: Tensor<TestBackend, 2> = Tensor::from_data([[4.0, 5.0, 6.0]]);
|
||||
|
||||
let output = Tensor::stack(vec![tensor_1, tensor_2], 0);
|
||||
|
||||
let data_expected = Data::from([[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]);
|
||||
output.into_data().assert_approx_eq(&data_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_stack_ops_int() {
|
||||
let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 3]]);
|
||||
let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data([[4, 5, 6]]);
|
||||
|
||||
let output = Tensor::stack(vec![tensor_1, tensor_2], 0);
|
||||
|
||||
let data_expected = Data::from([[[1, 2, 3]], [[4, 5, 6]]]);
|
||||
assert_eq!(&output.into_data(), &data_expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_stack_ops_bool() {
|
||||
let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data([[false, true, true]]);
|
||||
let tensor_2 = Tensor::<TestBackend, 2, Bool>::from_data([[true, true, false]]);
|
||||
|
||||
let output = Tensor::stack(vec![tensor_1, tensor_2], 0);
|
||||
|
||||
let data_expected = Data::from([[[false, true, true]], [[true, true, false]]]);
|
||||
assert_eq!(&output.into_data(), &data_expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_stack_ops_2d_dim1() {
|
||||
let tensor_1: Tensor<TestBackend, 2> = Tensor::from_data([[1.0, 2.0, 3.0]]);
|
||||
let tensor_2: Tensor<TestBackend, 2> = Tensor::from_data([[4.0, 5.0, 6.0]]);
|
||||
|
||||
let output = Tensor::stack(vec![tensor_1, tensor_2], 1);
|
||||
|
||||
let data_expected = Data::from([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]);
|
||||
output.into_data().assert_approx_eq(&data_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_stack_ops_3d() {
|
||||
let tensor_1: Tensor<TestBackend, 3> =
|
||||
TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]);
|
||||
let tensor_2: Tensor<TestBackend, 3> =
|
||||
TestTensor::from_data([[[4.0, 5.0, 6.0]], [[4.1, 5.1, 6.1]]]);
|
||||
|
||||
let output = Tensor::stack(vec![tensor_1, tensor_2], 0);
|
||||
|
||||
let data_expected = Data::from([
|
||||
[[[1.0000, 2.0000, 3.0000]], [[1.1000, 2.1000, 3.1000]]],
|
||||
[[[4.0000, 5.0000, 6.0000]], [[4.1000, 5.1000, 6.1000]]],
|
||||
]);
|
||||
output.into_data().assert_approx_eq(&data_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_dimensions_are_not_the_same() {
|
||||
let tensor_1: Tensor<TestBackend, 2> =
|
||||
Tensor::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
|
||||
let tensor_2: Tensor<TestBackend, 2> = Tensor::from_data([[4.0, 5.0]]);
|
||||
|
||||
let output: Tensor<TestBackend, 3> = Tensor::stack(vec![tensor_1, tensor_2], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_list_of_vectors_is_empty() {
|
||||
let tensors: Vec<Tensor<TestBackend, 2>> = vec![];
|
||||
let output: Tensor<TestBackend, 3> = TestTensor::stack(tensors, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_stack_exceeds_dimension() {
|
||||
let tensor_1: Tensor<TestBackend, 3> =
|
||||
Tensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]);
|
||||
let tensor_2: Tensor<TestBackend, 3> = Tensor::from_data([[[4.0, 5.0, 6.0]]]);
|
||||
|
||||
let output: Tensor<TestBackend, 4> = TestTensor::stack(vec![tensor_1, tensor_2], 3);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue