diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs index 8c6323652..7db4a1b39 100644 --- a/burn-tensor/src/tensor/api/base.rs +++ b/burn-tensor/src/tensor/api/base.rs @@ -112,6 +112,40 @@ where Tensor::new(K::reshape::(self.primitive, new_dims.into())) } + /// Unsqueeze the current tensor. Create new dimensions to fit the given size. + /// + /// # Panics + /// + /// If the output size is higher than the current tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let tensor = Tensor::::ones(Shape::new([3, 3])); + /// let tensor = tensor.unsqueeze::<4>(); + /// println!("{:?}", tensor.shape()); + /// // Shape { dims: [1, 1, 3, 3] } + /// } + /// ``` + pub fn unsqueeze(self) -> Tensor { + if D2 < D { + panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}") + } + + let mut dims = [1; D2]; + let num_ones = D2 - D; + let shape = self.shape(); + + dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]); + + let shape = Shape::new(dims); + self.reshape(shape) + } + /// Returns a tensor containing the elements selected from the given ranges. /// /// # Panics diff --git a/burn-tensor/src/tensor/api/float.rs b/burn-tensor/src/tensor/api/float.rs index dd36448c5..418456df3 100644 --- a/burn-tensor/src/tensor/api/float.rs +++ b/burn-tensor/src/tensor/api/float.rs @@ -308,41 +308,7 @@ where Self::new(B::require_grad(self.primitive)) } - /// Unsqueeze the current tensor. Create new dimensions to fit the given size. - /// - /// # Panics - /// - /// If the output size is higher than the current tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let tensor = Tensor::::ones(Shape::new([3, 3])); - /// let tensor = tensor.unsqueeze::<4>(); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [1, 1, 3, 3] } - /// } - /// ``` - /// TODO move this function to the base. - pub fn unsqueeze(self) -> Tensor { - if D2 < D { - panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}") - } - - let mut dims = [1; D2]; - let num_ones = D2 - D; - let shape = self.shape(); - - dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]); - - let shape = Shape::new(dims); - self.reshape(shape) - } - + /// Applies the relu function to the tensor. pub(crate) fn relu(self) -> Self { Self::new(B::relu(self.primitive)) }