mirror of https://github.com/tracel-ai/burn.git
Move unsqueeze op to the tensor's base (#261)
This commit is contained in:
parent
32d38bebc3
commit
4e9e6d2706
|
@ -112,6 +112,40 @@ where
|
||||||
Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
|
Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// If the output size is higher than the current tensor.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use burn_tensor::backend::Backend;
|
||||||
|
/// use burn_tensor::{Tensor, Shape};
|
||||||
|
///
|
||||||
|
/// fn example<B: Backend>() {
|
||||||
|
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
|
||||||
|
/// let tensor = tensor.unsqueeze::<4>();
|
||||||
|
/// println!("{:?}", tensor.shape());
|
||||||
|
/// // Shape { dims: [1, 1, 3, 3] }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
|
||||||
|
if D2 < D {
|
||||||
|
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut dims = [1; D2];
|
||||||
|
let num_ones = D2 - D;
|
||||||
|
let shape = self.shape();
|
||||||
|
|
||||||
|
dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);
|
||||||
|
|
||||||
|
let shape = Shape::new(dims);
|
||||||
|
self.reshape(shape)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a tensor containing the elements selected from the given ranges.
|
/// Returns a tensor containing the elements selected from the given ranges.
|
||||||
///
|
///
|
||||||
/// # Panics
|
/// # Panics
|
||||||
|
|
|
@ -308,41 +308,7 @@ where
|
||||||
Self::new(B::require_grad(self.primitive))
|
Self::new(B::require_grad(self.primitive))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
|
/// Applies the relu function to the tensor.
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
///
|
|
||||||
/// If the output size is higher than the current tensor.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// use burn_tensor::backend::Backend;
|
|
||||||
/// use burn_tensor::{Tensor, Shape};
|
|
||||||
///
|
|
||||||
/// fn example<B: Backend>() {
|
|
||||||
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
|
|
||||||
/// let tensor = tensor.unsqueeze::<4>();
|
|
||||||
/// println!("{:?}", tensor.shape());
|
|
||||||
/// // Shape { dims: [1, 1, 3, 3] }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
/// TODO move this function to the base.
|
|
||||||
pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2> {
|
|
||||||
if D2 < D {
|
|
||||||
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut dims = [1; D2];
|
|
||||||
let num_ones = D2 - D;
|
|
||||||
let shape = self.shape();
|
|
||||||
|
|
||||||
dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);
|
|
||||||
|
|
||||||
let shape = Shape::new(dims);
|
|
||||||
self.reshape(shape)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn relu(self) -> Self {
|
pub(crate) fn relu(self) -> Self {
|
||||||
Self::new(B::relu(self.primitive))
|
Self::new(B::relu(self.primitive))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue