mirror of https://github.com/tracel-ai/burn.git
Indices Operator (#1735)
This commit is contained in:
parent
cacc764205
commit
a2ad424fc8
|
@ -288,6 +288,7 @@ Those operations are only available for `Int` tensors.
|
||||||
| `tensor.float()` | `tensor.to(torch.float)` |
|
| `tensor.float()` | `tensor.to(torch.float)` |
|
||||||
| `tensor.from_ints(ints)` | N/A |
|
| `tensor.from_ints(ints)` | N/A |
|
||||||
| `tensor.int_random(shape, distribution, device)` | N/A |
|
| `tensor.int_random(shape, distribution, device)` | N/A |
|
||||||
|
| `tensor.cartesian_grid(shape, device)` | N/A |
|
||||||
|
|
||||||
# Bool Operations
|
# Bool Operations
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
use crate::{backend::Backend, ops::IntTensor, Int, Shape, Tensor};
|
||||||
|
use alloc::vec::Vec;
|
||||||
|
|
||||||
|
/// Generates a cartesian grid for the given tensor shape on the specified device.
|
||||||
|
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `shape` - The shape specifying the dimensions of the tensor.
|
||||||
|
/// * `device` - The device to create the tensor on.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `D2` is not equal to `D+1`.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use burn_tensor::Int;
|
||||||
|
/// use burn_tensor::{backend::Backend, Shape, Tensor};
|
||||||
|
/// fn example<B: Backend>() {
|
||||||
|
/// let device = Default::default();
|
||||||
|
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
|
||||||
|
/// println!("{}", result);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: usize>(
|
||||||
|
shape: S,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> IntTensor<B, D2> {
|
||||||
|
if D2 != D + 1 {
|
||||||
|
panic!("D2 must equal D + 1 for Tensor::indices")
|
||||||
|
}
|
||||||
|
|
||||||
|
let dims = shape.into().dims;
|
||||||
|
let mut indices: Vec<Tensor<B, D, Int>> = Vec::new();
|
||||||
|
|
||||||
|
for dim in 0..D {
|
||||||
|
let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);
|
||||||
|
|
||||||
|
let mut shape = [1; D];
|
||||||
|
shape[dim] = dims[dim];
|
||||||
|
let mut dim_range = dim_range.reshape(shape);
|
||||||
|
|
||||||
|
for (i, &item) in dims.iter().enumerate() {
|
||||||
|
if i == dim {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
dim_range = dim_range.repeat(i, item);
|
||||||
|
}
|
||||||
|
|
||||||
|
indices.push(dim_range);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::stack::<D2>(indices, D).into_primitive()
|
||||||
|
}
|
|
@ -1,4 +1,5 @@
|
||||||
use crate::{backend::Backend, Data, Float, Int, Tensor};
|
use crate::{backend::Backend, Data, Float, Int, Shape, Tensor};
|
||||||
|
|
||||||
use core::ops::Range;
|
use core::ops::Range;
|
||||||
|
|
||||||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
||||||
|
@ -70,6 +71,36 @@ where
|
||||||
Tensor::new(B::int_into_float(self.primitive))
|
Tensor::new(B::int_into_float(self.primitive))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates a cartesian grid for the given tensor shape on the specified device.
|
||||||
|
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `shape` - The shape specifying the dimensions of the tensor.
|
||||||
|
/// * `device` - The device to create the tensor on.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `D2` is not equal to `D+1`.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use burn_tensor::Int;
|
||||||
|
/// use burn_tensor::{backend::Backend, Shape, Tensor};
|
||||||
|
/// fn example<B: Backend>() {
|
||||||
|
/// let device = Default::default();
|
||||||
|
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
|
||||||
|
/// println!("{}", result);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn cartesian_grid<S: Into<Shape<D>>, const D2: usize>(
|
||||||
|
shape: S,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> Tensor<B, D2, Int> {
|
||||||
|
Tensor::new(B::int_cartesian_grid::<S, D, D2>(shape, device))
|
||||||
|
}
|
||||||
|
|
||||||
/// Sort the elements by value in ascending order along a given dimension.
|
/// Sort the elements by value in ascending order along a given dimension.
|
||||||
///
|
///
|
||||||
/// This sort is unstable (i.e., may reorder equal elements).
|
/// This sort is unstable (i.e., may reorder equal elements).
|
||||||
|
|
|
@ -4,6 +4,7 @@ mod argwhere;
|
||||||
mod autodiff;
|
mod autodiff;
|
||||||
mod base;
|
mod base;
|
||||||
mod bool;
|
mod bool;
|
||||||
|
mod cartesian_grid;
|
||||||
mod chunk;
|
mod chunk;
|
||||||
mod float;
|
mod float;
|
||||||
mod int;
|
mod int;
|
||||||
|
@ -15,6 +16,7 @@ mod sort;
|
||||||
pub use argwhere::argwhere;
|
pub use argwhere::argwhere;
|
||||||
pub use autodiff::*;
|
pub use autodiff::*;
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
|
pub use cartesian_grid::cartesian_grid;
|
||||||
pub use chunk::chunk;
|
pub use chunk::chunk;
|
||||||
pub use kind::*;
|
pub use kind::*;
|
||||||
pub use narrow::narrow;
|
pub use narrow::narrow;
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use super::cat::cat_with_slice_assign;
|
use super::cat::cat_with_slice_assign;
|
||||||
use super::repeat::repeat_with_slice_assign;
|
use super::repeat::repeat_with_slice_assign;
|
||||||
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||||
use crate::Tensor;
|
|
||||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
|
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
|
||||||
|
use crate::{cartesian_grid, Tensor};
|
||||||
use crate::{tensor::api::chunk, tensor::api::narrow};
|
use crate::{tensor::api::chunk, tensor::api::narrow};
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
use burn_common::reader::Reader;
|
use burn_common::reader::Reader;
|
||||||
|
@ -1032,6 +1032,36 @@ pub trait IntTensorOps<B: Backend> {
|
||||||
narrow::<B, D, Int>(tensor, dim, start, length)
|
narrow::<B, D, Int>(tensor, dim, start, length)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates a cartesian grid for the given tensor shape on the specified device.
|
||||||
|
/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `shape` - The shape specifying the dimensions of the tensor.
|
||||||
|
/// * `device` - The device to create the tensor on.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `D2` is not equal to `D+1`.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use burn_tensor::Int;
|
||||||
|
/// use burn_tensor::{backend::Backend, Shape, Tensor};
|
||||||
|
/// fn example<B: Backend>() {
|
||||||
|
/// let device = Default::default();
|
||||||
|
/// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
|
||||||
|
/// println!("{}", result);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
fn int_cartesian_grid<S: Into<Shape<D>>, const D: usize, const D2: usize>(
|
||||||
|
shape: S,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> IntTensor<B, D2> {
|
||||||
|
cartesian_grid::<B, _, D, D2>(shape, device)
|
||||||
|
}
|
||||||
|
|
||||||
/// Split the tensor along the given dimension into chunks.
|
/// Split the tensor along the given dimension into chunks.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
|
|
@ -97,6 +97,7 @@ macro_rules! testgen_all {
|
||||||
burn_tensor::testgen_sort_argsort!();
|
burn_tensor::testgen_sort_argsort!();
|
||||||
burn_tensor::testgen_topk!();
|
burn_tensor::testgen_topk!();
|
||||||
burn_tensor::testgen_remainder!();
|
burn_tensor::testgen_remainder!();
|
||||||
|
burn_tensor::testgen_cartesian_grid!();
|
||||||
|
|
||||||
// test stats
|
// test stats
|
||||||
burn_tensor::testgen_var!();
|
burn_tensor::testgen_var!();
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
#[burn_tensor_testgen::testgen(cartesian_grid)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::backend::Backend;
|
||||||
|
use burn_tensor::{Data, Int, Shape, Tensor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cartesian_grid() {
|
||||||
|
let device = <TestBackend as Backend>::Device::default();
|
||||||
|
|
||||||
|
// Test a single element tensor
|
||||||
|
let tensor: Tensor<TestBackend, 2, Int> =
|
||||||
|
Tensor::<TestBackend, 1, Int>::cartesian_grid([1], &device);
|
||||||
|
assert_eq!(tensor.into_data(), Data::from([[0]]));
|
||||||
|
|
||||||
|
// Test for a 2x2 tensor
|
||||||
|
let tensor: Tensor<TestBackend, 3, Int> =
|
||||||
|
Tensor::<TestBackend, 2, Int>::cartesian_grid([2, 2], &device);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.into_data(),
|
||||||
|
Data::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ mod arange_step;
|
||||||
mod arg;
|
mod arg;
|
||||||
mod argwhere_nonzero;
|
mod argwhere_nonzero;
|
||||||
mod bool;
|
mod bool;
|
||||||
|
mod cartesian_grid;
|
||||||
mod cast;
|
mod cast;
|
||||||
mod cat;
|
mod cat;
|
||||||
mod chunk;
|
mod chunk;
|
||||||
|
|
Loading…
Reference in New Issue