mirror of https://github.com/tracel-ai/burn.git
fix/docs/chunk (#1006)
This commit is contained in:
parent
3301aedcc3
commit
aa3180d0c7
|
@ -532,20 +532,22 @@ where
|
|||
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
|
||||
}
|
||||
|
||||
let chunk_size = size / chunks;
|
||||
let cnt_additional = size % chunks;
|
||||
let mut tensors = Vec::with_capacity(chunks);
|
||||
|
||||
let mut sum_chunk_size = 0;
|
||||
for i in 0..chunks {
|
||||
let chunk_size = if i < cnt_additional {
|
||||
chunk_size + 1
|
||||
} else {
|
||||
chunk_size
|
||||
};
|
||||
|
||||
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
|
||||
sum_chunk_size += chunk_size;
|
||||
if size % chunks == 0 {
|
||||
let chunk_size = size / chunks;
|
||||
for _ in 0..chunks {
|
||||
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
|
||||
sum_chunk_size += chunk_size;
|
||||
}
|
||||
} else {
|
||||
let chunk_size = (size / chunks) + 1; // assumes not divisible
|
||||
for _ in 0..chunks - 1 {
|
||||
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
|
||||
sum_chunk_size += chunk_size;
|
||||
}
|
||||
let remainder = size % chunk_size;
|
||||
tensors.push(self.clone().narrow(dim, sum_chunk_size, remainder));
|
||||
}
|
||||
|
||||
tensors
|
||||
|
|
|
@ -41,6 +41,18 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_not_evenly_divisible_remains_several() {
|
||||
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..100).chunk(8, 0);
|
||||
assert_eq!(tensors.len(), 8);
|
||||
|
||||
let expected = [13, 13, 13, 13, 13, 13, 13, 9];
|
||||
|
||||
for (index, tensor) in tensors.iter().enumerate() {
|
||||
assert_eq!(tensor.shape().dims[0], expected[index]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_not_divisible() {
|
||||
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..6).chunk(7, 0);
|
||||
|
|
Loading…
Reference in New Issue