diff --git a/crates/burn-tch/src/ops/bool_tensor.rs b/crates/burn-tch/src/ops/bool_tensor.rs index fc7f4d4b5..d92f428a1 100644 --- a/crates/burn-tch/src/ops/bool_tensor.rs +++ b/crates/burn-tch/src/ops/bool_tensor.rs @@ -155,6 +155,8 @@ impl BoolTensorOps for LibTorch { .tensor .nonzero_numpy() .into_iter() + // As opposed to tch, the resulting vector should be empty for zero tensors + .filter_map(|t| if t.numel() > 0 { Some(t) } else { None }) .map(TchTensor::new) .collect() } diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index b1718ad5c..9671d86c9 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -6,7 +6,7 @@ use crate::{ argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, TensorData, }; -use alloc::vec::Vec; +use alloc::{vec, vec::Vec}; use core::{future::Future, ops::Range}; /// Bool Tensor API for basic operations, see [tensor](crate::Tensor) @@ -426,12 +426,18 @@ pub trait BoolTensorOps { /// # Returns /// /// A vector of tensors, one for each dimension of the given tensor, containing the indices of - /// the non-zero elements in that dimension. + /// the non-zero elements in that dimension. If all elements are zero, the vector is empty. fn bool_nonzero( tensor: BoolTensor, ) -> impl Future>> + Send { async { let indices = B::bool_argwhere(tensor).await; + + if B::int_shape(&indices).num_elements() == 0 { + // Return empty vec when all elements are zero + return vec![]; + } + let dims = B::int_shape(&indices).dims; B::int_chunk(indices, dims[1], 1) .into_iter() diff --git a/crates/burn-tensor/src/tests/ops/argwhere_nonzero.rs b/crates/burn-tensor/src/tests/ops/argwhere_nonzero.rs index 3ce229748..cb8a4fb11 100644 --- a/crates/burn-tensor/src/tests/ops/argwhere_nonzero.rs +++ b/crates/burn-tensor/src/tests/ops/argwhere_nonzero.rs @@ -91,4 +91,12 @@ mod tests { actual.assert_eq(&data_expected[idx], false) } } + + #[test] + fn test_nonzero_empty() { + let tensor = TestTensorBool::<1>::from([false, false, false, false, false]); + let output = tensor.nonzero(); + + assert_eq!(output.len(), 0); + } }