diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index cf9406ee1..408709af6 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -288,6 +288,7 @@ Those operations are only available for `Int` tensors. | `tensor.float()` | `tensor.to(torch.float)` | | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | +| `tensor.cartesian_grid(shape, device)` | N/A | # Bool Operations diff --git a/crates/burn-tensor/src/tensor/api/cartesian_grid.rs b/crates/burn-tensor/src/tensor/api/cartesian_grid.rs new file mode 100644 index 000000000..b03a132cd --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/cartesian_grid.rs @@ -0,0 +1,56 @@ +use crate::{backend::Backend, ops::IntTensor, Int, Shape, Tensor}; +use alloc::vec::Vec; + +/// Generates a cartesian grid for the given tensor shape on the specified device. +/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element. +/// +/// # Arguments +/// +/// * `shape` - The shape specifying the dimensions of the tensor. +/// * `device` - The device to create the tensor on. +/// +/// # Panics +/// +/// Panics if `D2` is not equal to `D+1`. +/// +/// # Examples +/// +/// ```rust +/// use burn_tensor::Int; +/// use burn_tensor::{backend::Backend, Shape, Tensor}; +/// fn example() { +/// let device = Default::default(); +/// let result: Tensor = Tensor::::cartesian_grid([2, 3], &device); +/// println!("{}", result); +/// } +/// ``` +pub fn cartesian_grid>, const D: usize, const D2: usize>( + shape: S, + device: &B::Device, +) -> IntTensor { + if D2 != D + 1 { + panic!("D2 must equal D + 1 for Tensor::indices") + } + + let dims = shape.into().dims; + let mut indices: Vec> = Vec::new(); + + for dim in 0..D { + let dim_range: Tensor = Tensor::arange(0..dims[dim] as i64, device); + + let mut shape = [1; D]; + shape[dim] = dims[dim]; + let mut dim_range = dim_range.reshape(shape); + + for (i, &item) in dims.iter().enumerate() { + if i == dim { + continue; + } + dim_range = dim_range.repeat(i, item); + } + + indices.push(dim_range); + } + + Tensor::stack::(indices, D).into_primitive() +} diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 75060b013..5b0740cd2 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -1,4 +1,5 @@ -use crate::{backend::Backend, Data, Float, Int, Tensor}; +use crate::{backend::Backend, Data, Float, Int, Shape, Tensor}; + use core::ops::Range; #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] @@ -70,6 +71,36 @@ where Tensor::new(B::int_into_float(self.primitive)) } + /// Generates a cartesian grid for the given tensor shape on the specified device. + /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element. + /// + /// # Arguments + /// + /// * `shape` - The shape specifying the dimensions of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Panics + /// + /// Panics if `D2` is not equal to `D+1`. + /// + /// # Examples + /// + /// ```rust + /// use burn_tensor::Int; + /// use burn_tensor::{backend::Backend, Shape, Tensor}; + /// fn example() { + /// let device = Default::default(); + /// let result: Tensor = Tensor::::cartesian_grid([2, 3], &device); + /// println!("{}", result); + /// } + /// ``` + pub fn cartesian_grid>, const D2: usize>( + shape: S, + device: &B::Device, + ) -> Tensor { + Tensor::new(B::int_cartesian_grid::(shape, device)) + } + /// Sort the elements by value in ascending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs index 90d2a7a0a..62118c93a 100644 --- a/crates/burn-tensor/src/tensor/api/mod.rs +++ b/crates/burn-tensor/src/tensor/api/mod.rs @@ -4,6 +4,7 @@ mod argwhere; mod autodiff; mod base; mod bool; +mod cartesian_grid; mod chunk; mod float; mod int; @@ -15,6 +16,7 @@ mod sort; pub use argwhere::argwhere; pub use autodiff::*; pub use base::*; +pub use cartesian_grid::cartesian_grid; pub use chunk::chunk; pub use kind::*; pub use narrow::narrow; diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index 8f3fc56d5..be792316a 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -1,8 +1,8 @@ use super::cat::cat_with_slice_assign; use super::repeat::repeat_with_slice_assign; use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; -use crate::Tensor; use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int}; +use crate::{cartesian_grid, Tensor}; use crate::{tensor::api::chunk, tensor::api::narrow}; use alloc::vec::Vec; use burn_common::reader::Reader; @@ -1032,6 +1032,36 @@ pub trait IntTensorOps { narrow::(tensor, dim, start, length) } + /// Generates a cartesian grid for the given tensor shape on the specified device. + /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element. + /// + /// # Arguments + /// + /// * `shape` - The shape specifying the dimensions of the tensor. + /// * `device` - The device to create the tensor on. + /// + /// # Panics + /// + /// Panics if `D2` is not equal to `D+1`. + /// + /// # Examples + /// + /// ```rust + /// use burn_tensor::Int; + /// use burn_tensor::{backend::Backend, Shape, Tensor}; + /// fn example() { + /// let device = Default::default(); + /// let result: Tensor = Tensor::::cartesian_grid([2, 3], &device); + /// println!("{}", result); + /// } + /// ``` + fn int_cartesian_grid>, const D: usize, const D2: usize>( + shape: S, + device: &B::Device, + ) -> IntTensor { + cartesian_grid::(shape, device) + } + /// Split the tensor along the given dimension into chunks. /// /// # Arguments diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 527860306..735f8a6e8 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -97,6 +97,7 @@ macro_rules! testgen_all { burn_tensor::testgen_sort_argsort!(); burn_tensor::testgen_topk!(); burn_tensor::testgen_remainder!(); + burn_tensor::testgen_cartesian_grid!(); // test stats burn_tensor::testgen_var!(); diff --git a/crates/burn-tensor/src/tests/ops/cartesian_grid.rs b/crates/burn-tensor/src/tests/ops/cartesian_grid.rs new file mode 100644 index 000000000..f26453410 --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/cartesian_grid.rs @@ -0,0 +1,24 @@ +#[burn_tensor_testgen::testgen(cartesian_grid)] +mod tests { + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Int, Shape, Tensor}; + + #[test] + fn test_cartesian_grid() { + let device = ::Device::default(); + + // Test a single element tensor + let tensor: Tensor = + Tensor::::cartesian_grid([1], &device); + assert_eq!(tensor.into_data(), Data::from([[0]])); + + // Test for a 2x2 tensor + let tensor: Tensor = + Tensor::::cartesian_grid([2, 2], &device); + assert_eq!( + tensor.into_data(), + Data::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]) + ); + } +} diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index 1de78ac2a..b6fe1325d 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -8,6 +8,7 @@ mod arange_step; mod arg; mod argwhere_nonzero; mod bool; +mod cartesian_grid; mod cast; mod cat; mod chunk;