feature(tensor): Add chunk op (#998)

This commit is contained in:
David Chavez 2023-11-27 15:58:43 +01:00 committed by GitHub
parent 2fdf9a3591
commit 929b1786bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 123 additions and 0 deletions

View File

@ -50,6 +50,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |

View File

@ -511,6 +511,45 @@ where
self.slice(ranges_array)
}
/// Attempts to split the tensor along the given dimension into chunks.
/// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
///
/// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
/// Otherwise all chunks will be of equal size except for the last one.
///
/// # Panics
///
/// If the dimension is greater than the number of dimensions of the tensor.
///
/// # Returns
/// A vector of tensors.
pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
check!(TensorCheck::dim_ops::<D>("chunk", dim));
let size = self.shape().dims[dim];
if size < chunks {
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
}
let chunk_size = size / chunks;
let cnt_additional = size % chunks;
let mut tensors = Vec::with_capacity(chunks);
let mut sum_chunk_size = 0;
for i in 0..chunks {
let chunk_size = if i < cnt_additional {
chunk_size + 1
} else {
chunk_size
};
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
tensors
}
}
/// Iterator given by (Tensor::iter_dim).

View File

@ -38,6 +38,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_chunk!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_create_like!();

View File

@ -0,0 +1,81 @@
#[burn_tensor_testgen::testgen(chunk)]
mod tests {
use super::*;
use alloc::vec::Vec;
use burn_tensor::{Data, Int, Shape, Tensor};
fn test_chunk_evenly_divisible() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..12).chunk(6, 0);
assert_eq!(tensors.len(), 6);
let expected = vec![
Data::from([0, 1]),
Data::from([2, 3]),
Data::from([4, 5]),
Data::from([6, 7]),
Data::from([8, 9]),
Data::from([10, 11]),
];
for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.to_data(), expected[index]);
}
}
#[test]
fn test_chunk_not_evenly_divisible() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..11).chunk(6, 0);
assert_eq!(tensors.len(), 6);
let expected = vec![
Data::from([0, 1]),
Data::from([2, 3]),
Data::from([4, 5]),
Data::from([6, 7]),
Data::from([8, 9]),
Data::from([10]),
];
for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.to_data(), expected[index]);
}
}
#[test]
fn test_chunk_not_divisible() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..6).chunk(7, 0);
assert_eq!(tensors.len(), 6);
let expected = vec![
Data::from([0]),
Data::from([1]),
Data::from([2]),
Data::from([3]),
Data::from([4]),
Data::from([5]),
];
for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.to_data(), expected[index]);
}
}
#[test]
fn test_chunk_multi_dimension() {
let tensors: Vec<Tensor<TestBackend, 2, Int>> =
Tensor::from_data(Data::from([[0, 1, 2, 3]])).chunk(2, 1);
assert_eq!(tensors.len(), 2);
let expected = vec![Data::from([[0, 1]]), Data::from([[2, 3]])];
for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.to_data(), expected[index]);
}
}
#[test]
#[should_panic]
fn test_invalid_dim() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..12).chunk(6, 1);
}
}

View File

@ -6,6 +6,7 @@ mod arange_step;
mod arg;
mod cast;
mod cat;
mod chunk;
mod clamp;
mod cos;
mod create_like;