Move unsqueeze op to the tensor's base (#261)

This commit is contained in:
Dilshod Tadjibaev 2023-04-01 13:20:48 -05:00 committed by GitHub
parent 32d38bebc3
commit 4e9e6d2706
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 35 deletions

View File

@ -112,6 +112,40 @@ where
Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
}
/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
///
/// # Panics
///
/// If the output size is higher than the current tensor.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
/// let tensor = tensor.unsqueeze::<4>();
/// println!("{:?}", tensor.shape());
/// // Shape { dims: [1, 1, 3, 3] }
/// }
/// ```
pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
if D2 < D {
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")
}
let mut dims = [1; D2];
let num_ones = D2 - D;
let shape = self.shape();
dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);
let shape = Shape::new(dims);
self.reshape(shape)
}
/// Returns a tensor containing the elements selected from the given ranges.
///
/// # Panics

View File

@ -308,41 +308,7 @@ where
Self::new(B::require_grad(self.primitive))
}
/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
///
/// # Panics
///
/// If the output size is higher than the current tensor.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
/// let tensor = tensor.unsqueeze::<4>();
/// println!("{:?}", tensor.shape());
/// // Shape { dims: [1, 1, 3, 3] }
/// }
/// ```
/// TODO move this function to the base.
pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2> {
if D2 < D {
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")
}
let mut dims = [1; D2];
let num_ones = D2 - D;
let shape = self.shape();
dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);
let shape = Shape::new(dims);
self.reshape(shape)
}
/// Applies the relu function to the tensor.
pub(crate) fn relu(self) -> Self {
Self::new(B::relu(self.primitive))
}