Nonzero should return an empty vec for zero tensors (#2212)

* Nonzero should return an empty vec for zero tensors

* Add nonzero empty test

* Add missing import

---------

Co-authored-by: Nathaniel Simard <nathaniel.simard.42@gmail.com>
This commit is contained in:
Guillaume Lagrange 2024-09-03 09:00:58 -04:00 committed by GitHub
parent 96a23408d2
commit cc214d366c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 2 deletions

View File

@ -155,6 +155,8 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
.tensor .tensor
.nonzero_numpy() .nonzero_numpy()
.into_iter() .into_iter()
// As opposed to tch, the resulting vector should be empty for zero tensors
.filter_map(|t| if t.numel() > 0 { Some(t) } else { None })
.map(TchTensor::new) .map(TchTensor::new)
.collect() .collect()
} }

View File

@ -6,7 +6,7 @@ use crate::{
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
TensorData, TensorData,
}; };
use alloc::vec::Vec; use alloc::{vec, vec::Vec};
use core::{future::Future, ops::Range}; use core::{future::Future, ops::Range};
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor) /// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
@ -426,12 +426,18 @@ pub trait BoolTensorOps<B: Backend> {
/// # Returns /// # Returns
/// ///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension. /// the non-zero elements in that dimension. If all elements are zero, the vector is empty.
fn bool_nonzero<const D: usize>( fn bool_nonzero<const D: usize>(
tensor: BoolTensor<B, D>, tensor: BoolTensor<B, D>,
) -> impl Future<Output = Vec<IntTensor<B, 1>>> + Send { ) -> impl Future<Output = Vec<IntTensor<B, 1>>> + Send {
async { async {
let indices = B::bool_argwhere(tensor).await; let indices = B::bool_argwhere(tensor).await;
if B::int_shape(&indices).num_elements() == 0 {
// Return empty vec when all elements are zero
return vec![];
}
let dims = B::int_shape(&indices).dims; let dims = B::int_shape(&indices).dims;
B::int_chunk(indices, dims[1], 1) B::int_chunk(indices, dims[1], 1)
.into_iter() .into_iter()

View File

@ -91,4 +91,12 @@ mod tests {
actual.assert_eq(&data_expected[idx], false) actual.assert_eq(&data_expected[idx], false)
} }
} }
#[test]
fn test_nonzero_empty() {
let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);
let output = tensor.nonzero();
assert_eq!(output.len(), 0);
}
} }