Feat: Add `movedim` tensor operator (#1876)

* (burn-tensor): add movedim function to tensor API

---------

Co-authored-by: Georgy Andreev <g.andreev@insilicomedicine.com>
George 2024-06-14 17:01:38 +04:00 committed by GitHub
parent 47a81270e1
commit b71c300638
6 changed files with 486 additions and 0 deletions


@@ -160,6 +160,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
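
For orientation, a minimal usage sketch of the new entry (the generic backend `B`, the device argument, and the shapes below are illustrative assumptions, not part of the book page):

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

// Moves dim 0 of a [2, 3, 4] tensor to the last position, giving shape [3, 4, 2];
// equivalent to PyTorch's `torch.arange(24).reshape(2, 3, 4).movedim(0, 2)`.
fn movedim_example<B: Backend>(device: &B::Device) -> Tensor<B, 3, Int> {
    let tensor = Tensor::<B, 1, Int>::arange(0..24, device).reshape([2, 3, 4]);
    tensor.movedim(0, 2)
}
```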


@@ -171,6 +171,60 @@ where
Tensor::new(K::permute(self.primitive, transformed_axes))
}
/// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
///
/// Other dimensions of input that are not explicitly moved remain in their original order and appear
/// at the positions not specified in destination.
///
/// # Arguments
///
/// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
/// The values can be negative, in which case they are used as an offset from the end.
///
/// * `dst` - The destination position(s) for each of the source dimension(s). The values must also be unique.
/// They can be negative, in which case they are used as an offset from the end.
///
/// # Panics
///
/// - If the source and destination dimensions are not of the same length.
/// - If the source and destination vectors contain duplicate values.
/// - If the source and destination vectors contain values that are out of bounds.
///
/// # Returns
///
/// The tensor with the dimensions moved.
// This is syntactic sugar for `permute`. It is used widely enough that we define a separate op
// for it.
pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
let source_dims = src.into_dim_vec::<D>();
let destination_dims = dst.into_dim_vec::<D>();
check!(TensorCheck::movedim_args_length(
&source_dims,
&destination_dims
));
let mut m = [-1; D];
for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
m[d] = s as isize;
}
let mut axes: [isize; D] = [0; D];
let mut source_i = 0;
for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
*item = if m[dest_i] != -1 {
m[dest_i]
} else {
while source_dims.contains(&source_i) {
source_i += 1;
}
let result = source_i as isize;
source_i += 1;
result
};
}
self.permute(axes)
}
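// Illustrative sketch (not part of this diff): the permutation the loop above builds,
// written as a standalone helper with a hypothetical name. Moved axes are pinned at
// their destination slots and the untouched axes fill the remaining slots in order.
// For example, movedim(src = [0, 3], dst = [2, 0]) on a 4-D tensor is permute([3, 1, 0, 2]).
fn movedim_axes<const D: usize>(src: &[usize], dst: &[usize]) -> [usize; D] {
    let mut slots = [usize::MAX; D];
    for (&d, &s) in dst.iter().zip(src.iter()) {
        slots[d] = s;
    }
    let mut remaining = (0..D).filter(|i| !src.contains(i));
    slots.map(|s| if s == usize::MAX { remaining.next().unwrap() } else { s })
}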
/// Reverse the order of elements in the tensor along the given dimensions.
///
/// # Arguments
@@ -1983,6 +2037,67 @@ impl<B: Backend> BasicOps<B> for Bool {
}
}
/// Trait used for movedim arguments
pub trait MovedimArgs {
/// Converts the argument into a `Vec<usize>` of dimensions for the `tensor.movedim()` function
fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
impl MovedimArgs for Vec<i32> {
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
let set = self
.iter()
.map(|&dim| {
if dim < 0 {
(D as i32 + dim) as usize
} else {
dim as usize
}
})
.collect::<Vec<usize>>();
check!(TensorCheck::movedim_args_vec::<D>(&set));
set
}
}
impl MovedimArgs for Vec<usize> {
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
check!(TensorCheck::movedim_args_vec::<D>(&self));
self
}
}
impl MovedimArgs for usize {
#[allow(clippy::vec_init_then_push)]
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
check!(TensorCheck::movedim_args_usize::<D>(self));
let mut set = Vec::with_capacity(1);
set.push(self);
set
}
}
impl MovedimArgs for i32 {
#[allow(clippy::vec_init_then_push)]
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
check!(TensorCheck::movedim_args_i32::<D>(self));
let dim = if self < 0 {
(D as i32 + self) as usize
} else {
self as usize
};
let mut set = Vec::with_capacity(1);
set.push(dim);
set
}
}
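// Illustrative sketch (not part of this diff): the index normalization performed by the
// `i32` and `Vec<i32>` impls above; `wrap_dim` is a hypothetical helper name. For a rank-D
// tensor a negative index is an offset from the end, so wrap_dim::<3>(-1) == 2 and
// wrap_dim::<3>(-3) == 0; out-of-range values such as -4 or 3 are rejected by the
// `TensorCheck::movedim_args_i32` check shown later in this diff.
fn wrap_dim<const D: usize>(dim: i32) -> usize {
    if dim < 0 {
        (D as i32 + dim) as usize
    } else {
        dim as usize
    }
}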
/// Trait used for reshape arguments.
pub trait ReshapeArgs<const D2: usize> {
/// Converts to a shape.


@@ -172,6 +172,98 @@ impl TensorCheck {
check
}
pub(crate) fn movedim_args_usize<const D: usize>(dim: usize) -> Self {
let mut check = Self::Ok;
if dim >= D {
check = check.register(
"Movedim",
TensorError::new(
"The given dimension exceeds the number of dimensions of the current tensor.",
)
.details(format!(
"Current tensor has {D} dimensions, but the given dimension is {dim}.",
)),
);
}
check
}
pub(crate) fn movedim_args_i32<const D: usize>(dim: i32) -> Self {
let mut check = Self::Ok;
if dim < -(D as i32) || dim >= D as i32 {
check = check.register(
"Movedim",
TensorError::new(
"The given dimension is out of bounds for the current tensor dimensions.",
)
.details(format!(
"Current tensor has {D} dimensions, but the given dimension is {dim}.",
)),
);
}
check
}
pub(crate) fn movedim_args_vec<const D: usize>(dims: &Vec<usize>) -> Self {
let mut check = Self::Ok;
// Check out of bounds
if dims.iter().any(|&x| x >= D) {
check = check.register(
"Movedim",
TensorError::new("The given dimensions are out of bounds.").details(format!(
"Current tensor has {D} dimensions, but the given dimensions are {:?}.",
dims
)),
);
}
// Check there are no duplicates
for (i, &dim_i) in dims.iter().enumerate() {
for &dim_j in dims.iter().skip(i + 1) {
if dim_i == dim_j {
check = check.register(
"Movedim",
TensorError::new("The given dimensions contain duplicates.").details(
format!(
"The dimension {} is duplicated in the given dimensions {:?}.",
dim_i, dims
),
),
);
}
}
}
check
}
pub(crate) fn movedim_args_length(
source_dims: &Vec<usize>,
destination_dims: &Vec<usize>,
) -> Self {
let mut check = Self::Ok;
if source_dims.len() != destination_dims.len() {
check = check.register(
"Movedim",
TensorError::new(
"The number of dimensions in source and destination must be equal.",
)
.details(format!(
"Source dimensions: {:?}, Destination dimensions: {:?}.",
source_dims, destination_dims
)),
)
}
check
}
pub(crate) fn flatten<const D1: usize, const D2: usize>(
start_dim: usize,
end_dim: usize,
@@ -1104,4 +1196,42 @@ mod tests {
&8
));
}
#[test]
#[should_panic]
fn movedim_args_out_of_bounds() {
check!(TensorCheck::movedim_args_usize::<3>(5));
}
#[test]
fn movedim_args_i32() {
check!(TensorCheck::movedim_args_i32::<3>(-3));
}
#[test]
#[should_panic]
fn movedim_args_too_negative() {
check!(TensorCheck::movedim_args_i32::<3>(-4));
}
#[test]
#[should_panic]
fn movedim_args_vec_out_of_bounds() {
check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 3]));
}
#[test]
#[should_panic]
fn movedim_args_vec_duplicates() {
check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 1]));
}
#[test]
#[should_panic]
fn movedim_args_length() {
check!(TensorCheck::movedim_args_length(
&vec![0, 1],
&vec![0, 1, 2]
));
}
}


@@ -88,6 +88,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_any!();
burn_tensor::testgen_all_op!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_movedim!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!();


@@ -32,6 +32,7 @@ mod map_comparison;
mod mask;
mod matmul;
mod maxmin;
mod movedim;
mod mul;
mod narrow;
mod neg;


@@ -0,0 +1,238 @@
#[burn_tensor_testgen::testgen(movedim)]
mod tests {
use super::*;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Device, Int, Shape, Tensor};
#[test]
fn normal_int() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
let permuted = tensor.clone().movedim(0, 2);
// from pytorch:
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2)
let data_expected = Data::from([
[[0, 12], [1, 13], [2, 14], [3, 15]],
[[4, 16], [5, 17], [6, 18], [7, 19]],
[[8, 20], [9, 21], [10, 22], [11, 23]],
]);
assert_eq!(data_expected, permuted.into_data());
// Test with negative axis
let permuted = tensor.clone().movedim(0, -1);
assert_eq!(data_expected, permuted.into_data());
// Test with the same axis
let permuted = tensor.clone().movedim(0, 0);
assert_eq!(tensor.into_data(), permuted.into_data());
}
#[test]
fn normal_float() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
.reshape([2, 3, 4])
.float();
let permuted = tensor.clone().movedim(0, 2);
// from pytorch:
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).float()
let data_expected = Data::from([
[[0., 12.], [1., 13.], [2., 14.], [3., 15.]],
[[4., 16.], [5., 17.], [6., 18.], [7., 19.]],
[[8., 20.], [9., 21.], [10., 22.], [11., 23.]],
]);
assert_eq!(data_expected, permuted.into_data());
// Test with negative axis
let permuted = tensor.clone().movedim(0, -1);
assert_eq!(data_expected, permuted.into_data());
// Test with the same axis
let permuted = tensor.clone().movedim(0, 0);
assert_eq!(tensor.into_data(), permuted.into_data());
}
#[test]
fn normal_bool() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
.reshape([2, 3, 4])
.greater_elem(10);
let permuted = tensor.clone().movedim(0, 2);
// from pytorch:
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).gt(10)
let data_expected = Data::from([
[[false, true], [false, true], [false, true], [false, true]],
[[false, true], [false, true], [false, true], [false, true]],
[[false, true], [false, true], [false, true], [true, true]],
]);
assert_eq!(data_expected, permuted.into_data());
// Test with negative axis
let permuted = tensor.clone().movedim(0, -1);
assert_eq!(data_expected, permuted.into_data());
// Test with the same axis
let permuted = tensor.clone().movedim(0, 0);
assert_eq!(tensor.into_data(), permuted.into_data());
}
#[test]
fn vec_input_int() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);
// from pytorch
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0])
let data_expected = Data::from([
[[0, 1, 2, 3], [12, 13, 14, 15]],
[[4, 5, 6, 7], [16, 17, 18, 19]],
[[8, 9, 10, 11], [20, 21, 22, 23]],
]);
assert_eq!(data_expected, permuted.into_data());
// Test with negative axes
let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);
assert_eq!(data_expected, permuted.into_data());
// Test with the same axes
let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);
assert_eq!(tensor.into_data(), permuted.into_data());
}
#[test]
fn vec_input_float() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
.reshape([2, 3, 4])
.float();
let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);
// from pytorch
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).float()
let data_expected = Data::from([
[[0., 1., 2., 3.], [12., 13., 14., 15.]],
[[4., 5., 6., 7.], [16., 17., 18., 19.]],
[[8., 9., 10., 11.], [20., 21., 22., 23.]],
]);
assert_eq!(data_expected, permuted.into_data());
// Test with negative axes
let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);
assert_eq!(data_expected, permuted.into_data());
// Test with the same axes
let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);
assert_eq!(tensor.into_data(), permuted.into_data());
}
#[test]
fn vec_input_bool() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
.reshape([2, 3, 4])
.greater_elem(10);
let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);
// from pytorch
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).gt(10)
let data_expected = Data::from([
[[false, false, false, false], [true, true, true, true]],
[[false, false, false, false], [true, true, true, true]],
[[false, false, false, true], [true, true, true, true]],
]);
assert_eq!(data_expected, permuted.into_data());
// Test with negative axes
let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);
assert_eq!(data_expected, permuted.into_data());
// Test with the same axes
let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);
assert_eq!(tensor.into_data(), permuted.into_data());
}
#[test]
fn different_input_types() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
.reshape([2, 3, 4])
.float();
let permuted = tensor.clone().movedim(0_usize, 2_i32);
// from pytorch:
// import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).float()
let data_expected = Data::from([
[[0., 12.], [1., 13.], [2., 14.], [3., 15.]],
[[4., 16.], [5., 17.], [6., 18.], [7., 19.]],
[[8., 20.], [9., 21.], [10., 22.], [11., 23.]],
]);
assert_eq!(data_expected, permuted.into_data());
// Test with negative axis
let permuted = tensor.clone().movedim(0_usize, -1);
assert_eq!(data_expected, permuted.into_data());
// Test with the same axis
let permuted = tensor.clone().movedim(0_i32, 0_usize);
assert_eq!(tensor.into_data(), permuted.into_data());
}
#[test]
#[should_panic]
fn edge_different_sizes() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
// Test with source and destination of different lengths
let _ = tensor.clone().movedim(vec![0, 1], vec![0]);
}
#[test]
#[should_panic]
fn edge_out_of_bound_axis() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
// Test with an out-of-bounds axis
let _ = tensor.clone().movedim(0, 100);
}
#[test]
#[should_panic]
fn edge_vec_is_not_a_set() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
// Test with a repeated axis
let _ = tensor.clone().movedim(vec![0, 1, 1, 1, 1], vec![0, 0, 1]);
}
#[test]
#[should_panic]
fn edge_out_of_bound_axis_vec() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
// Test with an out-of-bounds axis
let _ = tensor.clone().movedim(vec![0, 100], vec![0, 1]);
}
}