mirror of https://github.com/tracel-ai/burn.git
Feat: Add `movedim` tensor operator (#1876)
* ✨ (burn-tensor): add movedim function to tensor API
---------
Co-authored-by: Georgy Andreev <g.andreev@insilicomedicine.com>
This commit is contained in:
parent
47a81270e1
commit
b71c300638
|
@ -160,6 +160,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
|||
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
|
||||
| `tensor.not_equal(other)` | `x != y` |
|
||||
| `tensor.permute(axes)` | `tensor.permute(axes)` |
|
||||
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
|
||||
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
|
||||
| `tensor.reshape(shape)` | `tensor.view(shape)` |
|
||||
| `tensor.shape()` | `tensor.shape` |
|
||||
|
|
|
@ -171,6 +171,60 @@ where
|
|||
Tensor::new(K::permute(self.primitive, transformed_axes))
|
||||
}
|
||||
|
||||
/// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
|
||||
///
|
||||
/// Other dimensions of input that are not explicitly moved remain in their original order and appear
|
||||
/// at the positions not specified in destination.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
|
||||
/// The values can be negative, in which case they are used as an offset from the end.
|
||||
///
|
||||
/// * `dst` - Destination positions for each of the original dims. These must also be unique.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// - If the source and destination dimensions are not of the same length.
|
||||
/// - If the source and destination vectors contain duplicate values.
|
||||
/// - If the source and destination vectors contain values that are out of bounds.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the dimensions moved.
|
||||
// This is a semantic sugar for `permute`. It is used widely enough, so we define a separate Op
|
||||
// for it
|
||||
pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
|
||||
let source_dims = src.into_dim_vec::<D>();
|
||||
let destination_dims = dst.into_dim_vec::<D>();
|
||||
|
||||
check!(TensorCheck::movedim_args_length(
|
||||
&source_dims,
|
||||
&destination_dims
|
||||
));
|
||||
|
||||
let mut m = [-1; D];
|
||||
for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
|
||||
m[d] = s as isize;
|
||||
}
|
||||
let mut axes: [isize; D] = [0; D];
|
||||
let mut source_i = 0;
|
||||
for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
|
||||
*item = if m[dest_i] != -1 {
|
||||
m[dest_i]
|
||||
} else {
|
||||
while source_dims.contains(&source_i) {
|
||||
source_i += 1;
|
||||
}
|
||||
let result = source_i as isize;
|
||||
source_i += 1;
|
||||
result
|
||||
};
|
||||
}
|
||||
|
||||
self.permute(axes)
|
||||
}
|
||||
|
||||
/// Reverse the order of elements in the tensor along the given dimensions.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -1983,6 +2037,67 @@ impl<B: Backend> BasicOps<B> for Bool {
|
|||
}
|
||||
}
|
||||
|
||||
/// Trait used for movedim arguments
|
||||
pub trait MovedimArgs {
|
||||
/// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function
|
||||
fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
|
||||
}
|
||||
|
||||
impl MovedimArgs for Vec<i32> {
|
||||
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
|
||||
let set = self
|
||||
.iter()
|
||||
.map(|&dim| {
|
||||
if dim < 0 {
|
||||
(D as i32 + dim) as usize
|
||||
} else {
|
||||
dim as usize
|
||||
}
|
||||
})
|
||||
.collect::<Vec<usize>>();
|
||||
check!(TensorCheck::movedim_args_vec::<D>(&set));
|
||||
|
||||
set
|
||||
}
|
||||
}
|
||||
|
||||
impl MovedimArgs for Vec<usize> {
|
||||
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
|
||||
check!(TensorCheck::movedim_args_vec::<D>(&self));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl MovedimArgs for usize {
|
||||
#[allow(clippy::vec_init_then_push)]
|
||||
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
|
||||
check!(TensorCheck::movedim_args_usize::<D>(self));
|
||||
|
||||
let mut set = Vec::with_capacity(1);
|
||||
set.push(self);
|
||||
|
||||
set
|
||||
}
|
||||
}
|
||||
|
||||
impl MovedimArgs for i32 {
|
||||
#[allow(clippy::vec_init_then_push)]
|
||||
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
|
||||
check!(TensorCheck::movedim_args_i32::<D>(self));
|
||||
|
||||
let dim = if self < 0 {
|
||||
(D as i32 + self) as usize
|
||||
} else {
|
||||
self as usize
|
||||
};
|
||||
|
||||
let mut set = Vec::with_capacity(1);
|
||||
set.push(dim);
|
||||
|
||||
set
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait used for reshape arguments.
|
||||
pub trait ReshapeArgs<const D2: usize> {
|
||||
/// Converts to a shape.
|
||||
|
|
|
@ -172,6 +172,98 @@ impl TensorCheck {
|
|||
check
|
||||
}
|
||||
|
||||
pub(crate) fn movedim_args_usize<const D: usize>(dim: usize) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
|
||||
if dim >= D {
|
||||
check = check.register(
|
||||
"Movedim",
|
||||
TensorError::new(
|
||||
"The given dimension exceeds the number of dimensions of the current tensor.",
|
||||
)
|
||||
.details(format!(
|
||||
"Current tensor has {D} dimensions, but the given dimension is {dim}.",
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
check
|
||||
}
|
||||
|
||||
pub(crate) fn movedim_args_i32<const D: usize>(dim: i32) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
|
||||
if dim < -(D as i32) || dim >= D as i32 {
|
||||
check = check.register(
|
||||
"Movedim",
|
||||
TensorError::new(
|
||||
"The given dimension is out of bounds for the current tensor dimensions.",
|
||||
)
|
||||
.details(format!(
|
||||
"Current tensor has {D} dimensions, but the given dimension is {dim}.",
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
check
|
||||
}
|
||||
|
||||
pub(crate) fn movedim_args_vec<const D: usize>(dims: &Vec<usize>) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
|
||||
// Check out of bounds
|
||||
if dims.iter().any(|&x| x >= D) {
|
||||
check = check.register(
|
||||
"Movedim",
|
||||
TensorError::new("The given dimensions are out of bounds.").details(format!(
|
||||
"Current tensor has {D} dimensions, but the given dimensions are {:?}.",
|
||||
dims
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
// Check there are no duplicates
|
||||
for (i, &dim_i) in dims.iter().enumerate() {
|
||||
for &dim_j in dims.iter().skip(i + 1) {
|
||||
if dim_i == dim_j {
|
||||
check = check.register(
|
||||
"Movedim",
|
||||
TensorError::new("The given dimensions contain duplicates.").details(
|
||||
format!(
|
||||
"The dimension {} is duplicated in the given dimensions {:?}.",
|
||||
dim_i, dims
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
check
|
||||
}
|
||||
|
||||
pub(crate) fn movedim_args_length(
|
||||
source_dims: &Vec<usize>,
|
||||
destination_dims: &Vec<usize>,
|
||||
) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
|
||||
if source_dims.len() != destination_dims.len() {
|
||||
check = check.register(
|
||||
"Movedim",
|
||||
TensorError::new(
|
||||
"The number of dimensions in source and destination must be equal.",
|
||||
)
|
||||
.details(format!(
|
||||
"Source dimensions: {:?}, Destination dimensions: {:?}.",
|
||||
source_dims, destination_dims
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
check
|
||||
}
|
||||
|
||||
pub(crate) fn flatten<const D1: usize, const D2: usize>(
|
||||
start_dim: usize,
|
||||
end_dim: usize,
|
||||
|
@ -1104,4 +1196,42 @@ mod tests {
|
|||
&8
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn movedim_args_out_of_bounds() {
|
||||
check!(TensorCheck::movedim_args_usize::<3>(5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn movedim_args_i32() {
|
||||
check!(TensorCheck::movedim_args_i32::<3>(-3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn movedim_args_too_negative() {
|
||||
check!(TensorCheck::movedim_args_i32::<3>(-4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn movedim_args_vec_out_of_bounds() {
|
||||
check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 3]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn movedim_args_vec_duplicates() {
|
||||
check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 1]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn movedim_args_length() {
|
||||
check!(TensorCheck::movedim_args_length(
|
||||
&vec![0, 1],
|
||||
&vec![0, 1, 2]
|
||||
));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,6 +88,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_any!();
|
||||
burn_tensor::testgen_all_op!();
|
||||
burn_tensor::testgen_permute!();
|
||||
burn_tensor::testgen_movedim!();
|
||||
burn_tensor::testgen_flip!();
|
||||
burn_tensor::testgen_bool!();
|
||||
burn_tensor::testgen_argwhere_nonzero!();
|
||||
|
|
|
@ -32,6 +32,7 @@ mod map_comparison;
|
|||
mod mask;
|
||||
mod matmul;
|
||||
mod maxmin;
|
||||
mod movedim;
|
||||
mod mul;
|
||||
mod narrow;
|
||||
mod neg;
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
#[burn_tensor_testgen::testgen(movedim)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Device, Int, Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn normal_int() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
let permuted = tensor.clone().movedim(0, 2);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2)
|
||||
let data_expected = Data::from([
|
||||
[[0, 12], [1, 13], [2, 14], [3, 15]],
|
||||
[[4, 16], [5, 17], [6, 18], [7, 19]],
|
||||
[[8, 20], [9, 21], [10, 22], [11, 23]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with negative axis
|
||||
let permuted = tensor.clone().movedim(0, -1);
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with the same axis
|
||||
let permuted = tensor.clone().movedim(0, 0);
|
||||
assert_eq!(tensor.into_data(), permuted.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normal_float() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.float();
|
||||
|
||||
let permuted = tensor.clone().movedim(0, 2);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).float()
|
||||
let data_expected = Data::from([
|
||||
[[0., 12.], [1., 13.], [2., 14.], [3., 15.]],
|
||||
[[4., 16.], [5., 17.], [6., 18.], [7., 19.]],
|
||||
[[8., 20.], [9., 21.], [10., 22.], [11., 23.]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with negative axis
|
||||
let permuted = tensor.clone().movedim(0, -1);
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with the same axis
|
||||
let permuted = tensor.clone().movedim(0, 0);
|
||||
assert_eq!(tensor.into_data(), permuted.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normal_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.greater_elem(10);
|
||||
|
||||
let permuted = tensor.clone().movedim(0, 2);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).gt(10)
|
||||
let data_expected = Data::from([
|
||||
[[false, true], [false, true], [false, true], [false, true]],
|
||||
[[false, true], [false, true], [false, true], [false, true]],
|
||||
[[false, true], [false, true], [false, true], [true, true]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with negative axis
|
||||
let permuted = tensor.clone().movedim(0, -1);
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with the same axis
|
||||
let permuted = tensor.clone().movedim(0, 0);
|
||||
assert_eq!(tensor.into_data(), permuted.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_input_int() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);
|
||||
|
||||
// from pytorch
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0])
|
||||
let data_expected = Data::from([
|
||||
[[0, 1, 2, 3], [12, 13, 14, 15]],
|
||||
[[4, 5, 6, 7], [16, 17, 18, 19]],
|
||||
[[8, 9, 10, 11], [20, 21, 22, 23]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with negative axes
|
||||
let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with the same axes
|
||||
let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);
|
||||
assert_eq!(tensor.into_data(), permuted.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_input_float() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.float();
|
||||
|
||||
let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);
|
||||
|
||||
// from pytorch
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).float()
|
||||
let data_expected = Data::from([
|
||||
[[0., 1., 2., 3.], [12., 13., 14., 15.]],
|
||||
[[4., 5., 6., 7.], [16., 17., 18., 19.]],
|
||||
[[8., 9., 10., 11.], [20., 21., 22., 23.]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with negative axes
|
||||
let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with the same axes
|
||||
let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);
|
||||
assert_eq!(tensor.into_data(), permuted.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_input_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.greater_elem(10);
|
||||
|
||||
let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);
|
||||
|
||||
// from pytorch
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).gt(10)
|
||||
let data_expected = Data::from([
|
||||
[[false, false, false, false], [true, true, true, true]],
|
||||
[[false, false, false, false], [true, true, true, true]],
|
||||
[[false, false, false, true], [true, true, true, true]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with negative axes
|
||||
let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with the same axes
|
||||
let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);
|
||||
assert_eq!(tensor.into_data(), permuted.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_input_types() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.float();
|
||||
|
||||
let permuted = tensor.clone().movedim(0_usize, 2_i32);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).float()
|
||||
let data_expected = Data::from([
|
||||
[[0., 12.], [1., 13.], [2., 14.], [3., 15.]],
|
||||
[[4., 16.], [5., 17.], [6., 18.], [7., 19.]],
|
||||
[[8., 20.], [9., 21.], [10., 22.], [11., 23.]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with negative axis
|
||||
let permuted = tensor.clone().movedim(0_usize, -1);
|
||||
assert_eq!(data_expected, permuted.into_data());
|
||||
|
||||
// Test with the same axis
|
||||
let permuted = tensor.clone().movedim(0_i32, 0_usize);
|
||||
assert_eq!(tensor.into_data(), permuted.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn edge_different_sizes() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
// Test with a repeated axis
|
||||
let _ = tensor.clone().movedim(vec![0, 1], vec![0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn edge_out_of_bound_axis() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
// Test with an out of bound axis
|
||||
let _ = tensor.clone().movedim(0, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn edge_vec_is_not_a_set() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
// Test with a repeated axis
|
||||
let _ = tensor.clone().movedim(vec![0, 1, 1, 1, 1], vec![0, 0, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn edge_out_of_bound_axis_vec() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
// Test with an out of bound axis
|
||||
let _ = tensor.clone().movedim(vec![0, 100], vec![0, 1]);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue