feature(tensor): Add unsqueeze_dim helper (#966)

This commit is contained in:
David Chavez 2023-11-20 16:29:40 +01:00 committed by GitHub
parent 20e9066b57
commit 49e16b6834
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 85 additions and 0 deletions

View File

@ -46,6 +46,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.device()` | `tensor.device` |

View File

@ -257,6 +257,40 @@ where
self.reshape(shape)
}
/// Creates a new tensor with a dimension of size one inserted at the specified position.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
/// let tensor: Tensor<B, 3> = tensor.unsqueeze_dim(1);
/// println!("{:?}", tensor.shape());
/// // Shape { dims: [3, 1, 3] }
/// }
/// ```
pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
check!(TensorCheck::unsqueeze_dim::<{ D }>(dim));
let mut dims = [1; D2];
let shape = self.shape();
dims[0..dim].copy_from_slice(&shape.dims[0..dim]);
if dim < D {
dims[dim] = 1;
dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
} else {
dims[dim] = 1;
}
let shape = Shape::new(dims);
self.reshape(shape)
}
/// Returns a tensor containing the elements selected from the given ranges.
///
/// # Panics

View File

@ -198,6 +198,21 @@ impl TensorCheck {
check
}
pub(crate) fn unsqueeze_dim<const D: usize>(dim: usize) -> Self {
let mut check = Self::Ok;
if dim > D {
check = check.register(
"Unsqueeze",
TensorError::new(format!(
"Can't unsqueeze at dimension {}, exceeds tensor dimensions (D={})",
dim, D
)),
);
}
check
}
pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
let mut check = Self::Ok;

View File

@ -34,4 +34,39 @@ mod tests {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(2);
}
/// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor.
#[test]
fn should_unsqueeze_dim() {
let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([2, 4, 1]));
let unsqueezed_tensor: Tensor<TestBackend, 4> = tensor.unsqueeze_dim(1);
let expected_shape = Shape::new([2, 1, 4, 1]);
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
}
/// Test if the function can successfully unsqueeze the first size 1 dimension of a 4D tensor.
#[test]
fn should_unsqueeze_dim_first() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(0);
let expected_shape = Shape::new([1, 2, 3, 4, 5]);
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
}
/// Test if the function can successfully unsqueeze the last size 1 dimension of a 4D tensor.
#[test]
fn should_unsqueeze_dim_last() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([5, 4, 3, 2]));
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(4);
let expected_shape = Shape::new([5, 4, 3, 2, 1]);
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
}
/// Test if the function panics when the unsqueezed dimension is out of bounds.
#[test]
#[should_panic]
fn should_unsqueeze_dim_panic() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(5);
}
}