mirror of https://github.com/tracel-ai/burn.git
feature(tensor): Add chunk op (#998)
This commit is contained in:
parent
2fdf9a3591
commit
929b1786bb
|
@ -50,6 +50,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
|||
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
|
||||
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
|
||||
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
|
||||
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
|
||||
| `tensor.device()` | `tensor.device` |
|
||||
| `tensor.to_device(device)` | `tensor.to(device)` |
|
||||
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
|
||||
|
|
|
@ -511,6 +511,45 @@ where
|
|||
|
||||
self.slice(ranges_array)
|
||||
}
|
||||
|
||||
/// Attempts to split the tensor along the given dimension into chunks.
|
||||
/// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
|
||||
///
|
||||
/// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
|
||||
/// Otherwise all chunks will be of equal size except for the last one.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the dimension is greater than the number of dimensions of the tensor.
|
||||
///
|
||||
/// # Returns
|
||||
/// A vector of tensors.
|
||||
pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
|
||||
check!(TensorCheck::dim_ops::<D>("chunk", dim));
|
||||
|
||||
let size = self.shape().dims[dim];
|
||||
if size < chunks {
|
||||
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
|
||||
}
|
||||
|
||||
let chunk_size = size / chunks;
|
||||
let cnt_additional = size % chunks;
|
||||
let mut tensors = Vec::with_capacity(chunks);
|
||||
|
||||
let mut sum_chunk_size = 0;
|
||||
for i in 0..chunks {
|
||||
let chunk_size = if i < cnt_additional {
|
||||
chunk_size + 1
|
||||
} else {
|
||||
chunk_size
|
||||
};
|
||||
|
||||
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
|
||||
sum_chunk_size += chunk_size;
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator given by (Tensor::iter_dim).
|
||||
|
|
|
@ -38,6 +38,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_arg!();
|
||||
burn_tensor::testgen_cast!();
|
||||
burn_tensor::testgen_cat!();
|
||||
burn_tensor::testgen_chunk!();
|
||||
burn_tensor::testgen_clamp!();
|
||||
burn_tensor::testgen_cos!();
|
||||
burn_tensor::testgen_create_like!();
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
#[burn_tensor_testgen::testgen(chunk)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{Data, Int, Shape, Tensor};
|
||||
|
||||
fn test_chunk_evenly_divisible() {
|
||||
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..12).chunk(6, 0);
|
||||
assert_eq!(tensors.len(), 6);
|
||||
|
||||
let expected = vec![
|
||||
Data::from([0, 1]),
|
||||
Data::from([2, 3]),
|
||||
Data::from([4, 5]),
|
||||
Data::from([6, 7]),
|
||||
Data::from([8, 9]),
|
||||
Data::from([10, 11]),
|
||||
];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
assert_eq!(tensor.to_data(), expected[index]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_not_evenly_divisible() {
|
||||
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..11).chunk(6, 0);
|
||||
assert_eq!(tensors.len(), 6);
|
||||
|
||||
let expected = vec![
|
||||
Data::from([0, 1]),
|
||||
Data::from([2, 3]),
|
||||
Data::from([4, 5]),
|
||||
Data::from([6, 7]),
|
||||
Data::from([8, 9]),
|
||||
Data::from([10]),
|
||||
];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
assert_eq!(tensor.to_data(), expected[index]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_not_divisible() {
|
||||
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..6).chunk(7, 0);
|
||||
assert_eq!(tensors.len(), 6);
|
||||
|
||||
let expected = vec![
|
||||
Data::from([0]),
|
||||
Data::from([1]),
|
||||
Data::from([2]),
|
||||
Data::from([3]),
|
||||
Data::from([4]),
|
||||
Data::from([5]),
|
||||
];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
assert_eq!(tensor.to_data(), expected[index]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_multi_dimension() {
|
||||
let tensors: Vec<Tensor<TestBackend, 2, Int>> =
|
||||
Tensor::from_data(Data::from([[0, 1, 2, 3]])).chunk(2, 1);
|
||||
assert_eq!(tensors.len(), 2);
|
||||
|
||||
let expected = vec![Data::from([[0, 1]]), Data::from([[2, 3]])];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
assert_eq!(tensor.to_data(), expected[index]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_invalid_dim() {
|
||||
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..12).chunk(6, 1);
|
||||
}
|
||||
}
|
|
@ -6,6 +6,7 @@ mod arange_step;
|
|||
mod arg;
|
||||
mod cast;
|
||||
mod cat;
|
||||
mod chunk;
|
||||
mod clamp;
|
||||
mod cos;
|
||||
mod create_like;
|
||||
|
|
Loading…
Reference in New Issue