mirror of https://github.com/tracel-ai/burn.git
Indices Operator (#1735)
This commit is contained in:
parent
cacc764205
commit
a2ad424fc8
|
@ -288,6 +288,7 @@ Those operations are only available for `Int` tensors.
|
|||
| `tensor.float()` | `tensor.to(torch.float)` |
|
||||
| `tensor.from_ints(ints)` | N/A |
|
||||
| `tensor.int_random(shape, distribution, device)` | N/A |
|
||||
| `tensor.cartesian_grid(shape, device)` | N/A |
|
||||
|
||||
# Bool Operations
|
||||
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
use crate::{backend::Backend, ops::IntTensor, Int, Shape, Tensor};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// Generates a cartesian grid for the given tensor shape on the specified device.
|
||||
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape specifying the dimensions of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `D2` is not equal to `D+1`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::Int;
|
||||
/// use burn_tensor::{backend::Backend, Shape, Tensor};
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
|
||||
/// println!("{}", result);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: usize>(
|
||||
shape: S,
|
||||
device: &B::Device,
|
||||
) -> IntTensor<B, D2> {
|
||||
if D2 != D + 1 {
|
||||
panic!("D2 must equal D + 1 for Tensor::indices")
|
||||
}
|
||||
|
||||
let dims = shape.into().dims;
|
||||
let mut indices: Vec<Tensor<B, D, Int>> = Vec::new();
|
||||
|
||||
for dim in 0..D {
|
||||
let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);
|
||||
|
||||
let mut shape = [1; D];
|
||||
shape[dim] = dims[dim];
|
||||
let mut dim_range = dim_range.reshape(shape);
|
||||
|
||||
for (i, &item) in dims.iter().enumerate() {
|
||||
if i == dim {
|
||||
continue;
|
||||
}
|
||||
dim_range = dim_range.repeat(i, item);
|
||||
}
|
||||
|
||||
indices.push(dim_range);
|
||||
}
|
||||
|
||||
Tensor::stack::<D2>(indices, D).into_primitive()
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
use crate::{backend::Backend, Data, Float, Int, Tensor};
|
||||
use crate::{backend::Backend, Data, Float, Int, Shape, Tensor};
|
||||
|
||||
use core::ops::Range;
|
||||
|
||||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
||||
|
@ -70,6 +71,36 @@ where
|
|||
Tensor::new(B::int_into_float(self.primitive))
|
||||
}
|
||||
|
||||
/// Generates a cartesian grid for the given tensor shape on the specified device.
|
||||
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape specifying the dimensions of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `D2` is not equal to `D+1`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::Int;
|
||||
/// use burn_tensor::{backend::Backend, Shape, Tensor};
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
|
||||
/// println!("{}", result);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn cartesian_grid<S: Into<Shape<D>>, const D2: usize>(
|
||||
shape: S,
|
||||
device: &B::Device,
|
||||
) -> Tensor<B, D2, Int> {
|
||||
Tensor::new(B::int_cartesian_grid::<S, D, D2>(shape, device))
|
||||
}
|
||||
|
||||
/// Sort the elements by value in ascending order along a given dimension.
|
||||
///
|
||||
/// This sort is unstable (i.e., may reorder equal elements).
|
||||
|
|
|
@ -4,6 +4,7 @@ mod argwhere;
|
|||
mod autodiff;
|
||||
mod base;
|
||||
mod bool;
|
||||
mod cartesian_grid;
|
||||
mod chunk;
|
||||
mod float;
|
||||
mod int;
|
||||
|
@ -15,6 +16,7 @@ mod sort;
|
|||
pub use argwhere::argwhere;
|
||||
pub use autodiff::*;
|
||||
pub use base::*;
|
||||
pub use cartesian_grid::cartesian_grid;
|
||||
pub use chunk::chunk;
|
||||
pub use kind::*;
|
||||
pub use narrow::narrow;
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use super::cat::cat_with_slice_assign;
|
||||
use super::repeat::repeat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
use crate::Tensor;
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
|
||||
use crate::{cartesian_grid, Tensor};
|
||||
use crate::{tensor::api::chunk, tensor::api::narrow};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
@ -1032,6 +1032,36 @@ pub trait IntTensorOps<B: Backend> {
|
|||
narrow::<B, D, Int>(tensor, dim, start, length)
|
||||
}
|
||||
|
||||
/// Generates a cartesian grid for the given tensor shape on the specified device.
|
||||
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape specifying the dimensions of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `D2` is not equal to `D+1`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use burn_tensor::Int;
|
||||
/// use burn_tensor::{backend::Backend, Shape, Tensor};
|
||||
/// fn example<B: Backend>() {
|
||||
/// let device = Default::default();
|
||||
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
|
||||
/// println!("{}", result);
|
||||
/// }
|
||||
/// ```
|
||||
fn int_cartesian_grid<S: Into<Shape<D>>, const D: usize, const D2: usize>(
|
||||
shape: S,
|
||||
device: &B::Device,
|
||||
) -> IntTensor<B, D2> {
|
||||
cartesian_grid::<B, _, D, D2>(shape, device)
|
||||
}
|
||||
|
||||
/// Split the tensor along the given dimension into chunks.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
|
@ -97,6 +97,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_sort_argsort!();
|
||||
burn_tensor::testgen_topk!();
|
||||
burn_tensor::testgen_remainder!();
|
||||
burn_tensor::testgen_cartesian_grid!();
|
||||
|
||||
// test stats
|
||||
burn_tensor::testgen_var!();
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
#[burn_tensor_testgen::testgen(cartesian_grid)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Int, Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_cartesian_grid() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
|
||||
// Test a single element tensor
|
||||
let tensor: Tensor<TestBackend, 2, Int> =
|
||||
Tensor::<TestBackend, 1, Int>::cartesian_grid([1], &device);
|
||||
assert_eq!(tensor.into_data(), Data::from([[0]]));
|
||||
|
||||
// Test for a 2x2 tensor
|
||||
let tensor: Tensor<TestBackend, 3, Int> =
|
||||
Tensor::<TestBackend, 2, Int>::cartesian_grid([2, 2], &device);
|
||||
assert_eq!(
|
||||
tensor.into_data(),
|
||||
Data::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
|
||||
);
|
||||
}
|
||||
}
|
|
@ -8,6 +8,7 @@ mod arange_step;
|
|||
mod arg;
|
||||
mod argwhere_nonzero;
|
||||
mod bool;
|
||||
mod cartesian_grid;
|
||||
mod cast;
|
||||
mod cat;
|
||||
mod chunk;
|
||||
|
|
Loading…
Reference in New Issue