mirror of https://github.com/tracel-ai/burn.git
feature(tensor): Add unsqueeze_dim helper (#966)
This commit is contained in:
parent
20e9066b57
commit
49e16b6834
|
@ -46,6 +46,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
|||
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
|
||||
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
|
||||
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
|
||||
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
|
||||
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
|
||||
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
|
||||
| `tensor.device()` | `tensor.device` |
|
||||
|
|
|
@ -257,6 +257,40 @@ where
|
|||
self.reshape(shape)
|
||||
}
|
||||
|
||||
/// Creates a new tensor with a dimension of size one inserted at the specified position.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::backend::Backend;
|
||||
/// use burn_tensor::{Tensor, Shape};
|
||||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
|
||||
/// let tensor: Tensor<B, 3> = tensor.unsqueeze_dim(1);
|
||||
/// println!("{:?}", tensor.shape());
|
||||
/// // Shape { dims: [3, 1, 3] }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
|
||||
check!(TensorCheck::unsqueeze_dim::<{ D }>(dim));
|
||||
|
||||
let mut dims = [1; D2];
|
||||
let shape = self.shape();
|
||||
|
||||
dims[0..dim].copy_from_slice(&shape.dims[0..dim]);
|
||||
|
||||
if dim < D {
|
||||
dims[dim] = 1;
|
||||
dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
|
||||
} else {
|
||||
dims[dim] = 1;
|
||||
}
|
||||
|
||||
let shape = Shape::new(dims);
|
||||
self.reshape(shape)
|
||||
}
|
||||
|
||||
/// Returns a tensor containing the elements selected from the given ranges.
|
||||
///
|
||||
/// # Panics
|
||||
|
|
|
@ -198,6 +198,21 @@ impl TensorCheck {
|
|||
check
|
||||
}
|
||||
|
||||
pub(crate) fn unsqueeze_dim<const D: usize>(dim: usize) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
if dim > D {
|
||||
check = check.register(
|
||||
"Unsqueeze",
|
||||
TensorError::new(format!(
|
||||
"Can't unsqueeze at dimension {}, exceeds tensor dimensions (D={})",
|
||||
dim, D
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
check
|
||||
}
|
||||
|
||||
pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
|
||||
|
|
|
@ -34,4 +34,39 @@ mod tests {
|
|||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
|
||||
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(2);
|
||||
}
|
||||
|
||||
/// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor.
|
||||
#[test]
|
||||
fn should_unsqueeze_dim() {
|
||||
let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([2, 4, 1]));
|
||||
let unsqueezed_tensor: Tensor<TestBackend, 4> = tensor.unsqueeze_dim(1);
|
||||
let expected_shape = Shape::new([2, 1, 4, 1]);
|
||||
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
|
||||
}
|
||||
|
||||
/// Test if the function can successfully unsqueeze the first size 1 dimension of a 4D tensor.
|
||||
#[test]
|
||||
fn should_unsqueeze_dim_first() {
|
||||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
|
||||
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(0);
|
||||
let expected_shape = Shape::new([1, 2, 3, 4, 5]);
|
||||
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
|
||||
}
|
||||
|
||||
/// Test if the function can successfully unsqueeze the last size 1 dimension of a 4D tensor.
|
||||
#[test]
|
||||
fn should_unsqueeze_dim_last() {
|
||||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([5, 4, 3, 2]));
|
||||
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(4);
|
||||
let expected_shape = Shape::new([5, 4, 3, 2, 1]);
|
||||
assert_eq!(unsqueezed_tensor.shape(), expected_shape);
|
||||
}
|
||||
|
||||
/// Test if the function panics when the unsqueezed dimension is out of bounds.
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_unsqueeze_dim_panic() {
|
||||
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
|
||||
let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(5);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue