Nonzero should return an empty vec for zero tensors (#2212)

* Nonzero should return an empty vec for zero tensors

* Add nonzero empty test

* Add missing import

---------

Co-authored-by: Nathaniel Simard <nathaniel.simard.42@gmail.com>
This commit is contained in:
Guillaume Lagrange 2024-09-03 09:00:58 -04:00 committed by GitHub
parent 96a23408d2
commit cc214d366c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 2 deletions

View File

@ -155,6 +155,8 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
.tensor
.nonzero_numpy()
.into_iter()
// As opposed to tch, the resulting vector should be empty for zero tensors
.filter_map(|t| if t.numel() > 0 { Some(t) } else { None })
.map(TchTensor::new)
.collect()
}

View File

@ -6,7 +6,7 @@ use crate::{
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
TensorData,
};
use alloc::vec::Vec;
use alloc::{vec, vec::Vec};
use core::{future::Future, ops::Range};
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
@ -426,12 +426,18 @@ pub trait BoolTensorOps<B: Backend> {
/// # Returns
///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension.
/// the non-zero elements in that dimension. If all elements are zero, the vector is empty.
fn bool_nonzero<const D: usize>(
tensor: BoolTensor<B, D>,
) -> impl Future<Output = Vec<IntTensor<B, 1>>> + Send {
async {
let indices = B::bool_argwhere(tensor).await;
if B::int_shape(&indices).num_elements() == 0 {
// Return empty vec when all elements are zero
return vec![];
}
let dims = B::int_shape(&indices).dims;
B::int_chunk(indices, dims[1], 1)
.into_iter()

View File

@ -91,4 +91,12 @@ mod tests {
actual.assert_eq(&data_expected[idx], false)
}
}
#[test]
fn test_nonzero_empty() {
let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);
let output = tensor.nonzero();
assert_eq!(output.len(), 0);
}
}