Indices Operator (#1735)

This commit is contained in:
McArthur 2024-05-29 13:05:31 +00:00 committed by GitHub
parent cacc764205
commit a2ad424fc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 148 additions and 2 deletions

View File

@ -288,6 +288,7 @@ Those operations are only available for `Int` tensors.
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
| `tensor.cartesian_grid(shape, device)` | N/A |
# Bool Operations

View File

@ -0,0 +1,56 @@
use crate::{backend::Backend, ops::IntTensor, Int, Shape, Tensor};
use alloc::vec::Vec;
/// Generates a cartesian grid for the given tensor shape on the specified device.
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
///
/// # Arguments
///
/// * `shape` - The shape specifying the dimensions of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Panics
///
/// Panics if `D2` is not equal to `D+1`.
///
/// # Examples
///
/// ```rust
/// use burn_tensor::Int;
/// use burn_tensor::{backend::Backend, Shape, Tensor};
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
/// println!("{}", result);
/// }
/// ```
pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: usize>(
shape: S,
device: &B::Device,
) -> IntTensor<B, D2> {
if D2 != D + 1 {
panic!("D2 must equal D + 1 for Tensor::indices")
}
let dims = shape.into().dims;
let mut indices: Vec<Tensor<B, D, Int>> = Vec::new();
for dim in 0..D {
let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);
let mut shape = [1; D];
shape[dim] = dims[dim];
let mut dim_range = dim_range.reshape(shape);
for (i, &item) in dims.iter().enumerate() {
if i == dim {
continue;
}
dim_range = dim_range.repeat(i, item);
}
indices.push(dim_range);
}
Tensor::stack::<D2>(indices, D).into_primitive()
}

View File

@ -1,4 +1,5 @@
use crate::{backend::Backend, Data, Float, Int, Tensor};
use crate::{backend::Backend, Data, Float, Int, Shape, Tensor};
use core::ops::Range;
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
@ -70,6 +71,36 @@ where
Tensor::new(B::int_into_float(self.primitive))
}
/// Generates a cartesian grid for the given tensor shape on the specified device.
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
///
/// # Arguments
///
/// * `shape` - The shape specifying the dimensions of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Panics
///
/// Panics if `D2` is not equal to `D+1`.
///
/// # Examples
///
/// ```rust
/// use burn_tensor::Int;
/// use burn_tensor::{backend::Backend, Shape, Tensor};
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
/// println!("{}", result);
/// }
/// ```
pub fn cartesian_grid<S: Into<Shape<D>>, const D2: usize>(
shape: S,
device: &B::Device,
) -> Tensor<B, D2, Int> {
Tensor::new(B::int_cartesian_grid::<S, D, D2>(shape, device))
}
/// Sort the elements by value in ascending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).

View File

@ -4,6 +4,7 @@ mod argwhere;
mod autodiff;
mod base;
mod bool;
mod cartesian_grid;
mod chunk;
mod float;
mod int;
@ -15,6 +16,7 @@ mod sort;
pub use argwhere::argwhere;
pub use autodiff::*;
pub use base::*;
pub use cartesian_grid::cartesian_grid;
pub use chunk::chunk;
pub use kind::*;
pub use narrow::narrow;

View File

@ -1,8 +1,8 @@
use super::cat::cat_with_slice_assign;
use super::repeat::repeat_with_slice_assign;
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use crate::Tensor;
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
use crate::{cartesian_grid, Tensor};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
use burn_common::reader::Reader;
@ -1032,6 +1032,36 @@ pub trait IntTensorOps<B: Backend> {
narrow::<B, D, Int>(tensor, dim, start, length)
}
/// Generates a cartesian grid for the given tensor shape on the specified device.
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
///
/// # Arguments
///
/// * `shape` - The shape specifying the dimensions of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Panics
///
/// Panics if `D2` is not equal to `D+1`.
///
/// # Examples
///
/// ```rust
/// use burn_tensor::Int;
/// use burn_tensor::{backend::Backend, Shape, Tensor};
/// fn example<B: Backend>() {
/// let device = Default::default();
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
/// println!("{}", result);
/// }
/// ```
fn int_cartesian_grid<S: Into<Shape<D>>, const D: usize, const D2: usize>(
shape: S,
device: &B::Device,
) -> IntTensor<B, D2> {
cartesian_grid::<B, _, D, D2>(shape, device)
}
/// Split the tensor along the given dimension into chunks.
///
/// # Arguments

View File

@ -97,6 +97,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_sort_argsort!();
burn_tensor::testgen_topk!();
burn_tensor::testgen_remainder!();
burn_tensor::testgen_cartesian_grid!();
// test stats
burn_tensor::testgen_var!();

View File

@ -0,0 +1,24 @@
#[burn_tensor_testgen::testgen(cartesian_grid)]
mod tests {
use super::*;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Int, Shape, Tensor};
#[test]
fn test_cartesian_grid() {
let device = <TestBackend as Backend>::Device::default();
// Test a single element tensor
let tensor: Tensor<TestBackend, 2, Int> =
Tensor::<TestBackend, 1, Int>::cartesian_grid([1], &device);
assert_eq!(tensor.into_data(), Data::from([[0]]));
// Test for a 2x2 tensor
let tensor: Tensor<TestBackend, 3, Int> =
Tensor::<TestBackend, 2, Int>::cartesian_grid([2, 2], &device);
assert_eq!(
tensor.into_data(),
Data::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
);
}
}

View File

@ -8,6 +8,7 @@ mod arange_step;
mod arg;
mod argwhere_nonzero;
mod bool;
mod cartesian_grid;
mod cast;
mod cat;
mod chunk;