From 49e16b68348a670e621ebd65ddc57ac75a814554 Mon Sep 17 00:00:00 2001 From: David Chavez Date: Mon, 20 Nov 2023 16:29:40 +0100 Subject: [PATCH] feature(tensor): Add unsqueeze_dim helper (#966) --- burn-book/src/building-blocks/tensor.md | 1 + burn-tensor/src/tensor/api/base.rs | 34 ++++++++++++++++++++++++ burn-tensor/src/tensor/api/check.rs | 15 +++++++++++ burn-tensor/src/tests/ops/squeeze.rs | 35 +++++++++++++++++++++++++ 4 files changed, 85 insertions(+) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 270a2477c..d996953cc 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -46,6 +46,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. | `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | | `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | | `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | +| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | | `tensor.slice(ranges)` | `tensor[(*ranges,)]` | | `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | | `tensor.device()` | `tensor.device` | diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs index 1f20fe017..7da58adc5 100644 --- a/burn-tensor/src/tensor/api/base.rs +++ b/burn-tensor/src/tensor/api/base.rs @@ -257,6 +257,40 @@ where self.reshape(shape) } + /// Creates a new tensor with a dimension of size one inserted at the specified position. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([3, 3])); + /// let tensor: Tensor = tensor.unsqueeze_dim(1); + /// println!("{:?}", tensor.shape()); + /// // Shape { dims: [3, 1, 3] } + /// } + /// ``` + pub fn unsqueeze_dim(self, dim: usize) -> Tensor { + check!(TensorCheck::unsqueeze_dim::<{ D }>(dim)); + + let mut dims = [1; D2]; + let shape = self.shape(); + + dims[0..dim].copy_from_slice(&shape.dims[0..dim]); + + if dim < D { + dims[dim] = 1; + dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]); + } else { + dims[dim] = 1; + } + + let shape = Shape::new(dims); + self.reshape(shape) + } + /// Returns a tensor containing the elements selected from the given ranges. /// /// # Panics diff --git a/burn-tensor/src/tensor/api/check.rs b/burn-tensor/src/tensor/api/check.rs index da237ee1a..0cecc57f9 100644 --- a/burn-tensor/src/tensor/api/check.rs +++ b/burn-tensor/src/tensor/api/check.rs @@ -198,6 +198,21 @@ impl TensorCheck { check } + pub(crate) fn unsqueeze_dim(dim: usize) -> Self { + let mut check = Self::Ok; + if dim > D { + check = check.register( + "Unsqueeze", + TensorError::new(format!( + "Can't unsqueeze at dimension {}, exceeds tensor dimensions (D={})", + dim, D + )), + ); + } + + check + } + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { let mut check = Self::Ok; diff --git a/burn-tensor/src/tests/ops/squeeze.rs b/burn-tensor/src/tests/ops/squeeze.rs index d8de064bd..c2f23245c 100644 --- a/burn-tensor/src/tests/ops/squeeze.rs +++ b/burn-tensor/src/tests/ops/squeeze.rs @@ -34,4 +34,39 @@ mod tests { let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); let squeezed_tensor: Tensor = tensor.squeeze(2); } + + /// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor. + #[test] + fn should_unsqueeze_dim() { + let tensor = Tensor::::ones(Shape::new([2, 4, 1])); + let unsqueezed_tensor: Tensor = tensor.unsqueeze_dim(1); + let expected_shape = Shape::new([2, 1, 4, 1]); + assert_eq!(unsqueezed_tensor.shape(), expected_shape); + } + + /// Test if the function can successfully unsqueeze the first size 1 dimension of a 4D tensor. + #[test] + fn should_unsqueeze_dim_first() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let unsqueezed_tensor: Tensor = tensor.unsqueeze_dim(0); + let expected_shape = Shape::new([1, 2, 3, 4, 5]); + assert_eq!(unsqueezed_tensor.shape(), expected_shape); + } + + /// Test if the function can successfully unsqueeze the last size 1 dimension of a 4D tensor. + #[test] + fn should_unsqueeze_dim_last() { + let tensor = Tensor::::ones(Shape::new([5, 4, 3, 2])); + let unsqueezed_tensor: Tensor = tensor.unsqueeze_dim(4); + let expected_shape = Shape::new([5, 4, 3, 2, 1]); + assert_eq!(unsqueezed_tensor.shape(), expected_shape); + } + + /// Test if the function panics when the unsqueezed dimension is out of bounds. + #[test] + #[should_panic] + fn should_unsqueeze_dim_panic() { + let tensor = Tensor::::ones(Shape::new([2, 3, 4, 5])); + let unsqueezed_tensor: Tensor = tensor.unsqueeze_dim(5); + } }