fix/docs/chunk (#1006)

This commit is contained in:
Louis Fortier-Dubois 2023-11-29 09:38:43 -05:00 committed by GitHub
parent 3301aedcc3
commit aa3180d0c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 12 deletions

View File

@ -532,20 +532,22 @@ where
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
}
let chunk_size = size / chunks;
let cnt_additional = size % chunks;
let mut tensors = Vec::with_capacity(chunks);
let mut sum_chunk_size = 0;
for i in 0..chunks {
let chunk_size = if i < cnt_additional {
chunk_size + 1
} else {
chunk_size
};
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
if size % chunks == 0 {
let chunk_size = size / chunks;
for _ in 0..chunks {
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
} else {
let chunk_size = (size / chunks) + 1; // assumes not divisible
for _ in 0..chunks - 1 {
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
let remainder = size % chunk_size;
tensors.push(self.clone().narrow(dim, sum_chunk_size, remainder));
}
tensors

View File

@ -41,6 +41,18 @@ mod tests {
}
}
#[test]
fn test_chunk_not_evenly_divisible_remains_several() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..100).chunk(8, 0);
assert_eq!(tensors.len(), 8);
let expected = [13, 13, 13, 13, 13, 13, 13, 9];
for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.shape().dims[0], expected[index]);
}
}
#[test]
fn test_chunk_not_divisible() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..6).chunk(7, 0);