mirror of https://github.com/tracel-ai/burn.git
Nonzero should return an empty vec for zero tensors (#2212)
* Nonzero should return an empty vec for zero tensors * Add nonzero empty test * Add missing import --------- Co-authored-by: Nathaniel Simard <nathaniel.simard.42@gmail.com>
This commit is contained in:
parent
96a23408d2
commit
cc214d366c
|
@ -155,6 +155,8 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
|
|||
.tensor
|
||||
.nonzero_numpy()
|
||||
.into_iter()
|
||||
// As opposed to tch, the resulting vector should be empty for zero tensors
|
||||
.filter_map(|t| if t.numel() > 0 { Some(t) } else { None })
|
||||
.map(TchTensor::new)
|
||||
.collect()
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
|||
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
|
||||
TensorData,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{vec, vec::Vec};
|
||||
use core::{future::Future, ops::Range};
|
||||
|
||||
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
|
||||
|
@ -426,12 +426,18 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
|
||||
/// the non-zero elements in that dimension.
|
||||
/// the non-zero elements in that dimension. If all elements are zero, the vector is empty.
|
||||
fn bool_nonzero<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
) -> impl Future<Output = Vec<IntTensor<B, 1>>> + Send {
|
||||
async {
|
||||
let indices = B::bool_argwhere(tensor).await;
|
||||
|
||||
if B::int_shape(&indices).num_elements() == 0 {
|
||||
// Return empty vec when all elements are zero
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let dims = B::int_shape(&indices).dims;
|
||||
B::int_chunk(indices, dims[1], 1)
|
||||
.into_iter()
|
||||
|
|
|
@ -91,4 +91,12 @@ mod tests {
|
|||
actual.assert_eq(&data_expected[idx], false)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonzero_empty() {
|
||||
let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);
|
||||
let output = tensor.nonzero();
|
||||
|
||||
assert_eq!(output.len(), 0);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue