mirror of https://github.com/tracel-ai/burn.git
Nonzero should return an empty vec for zero tensors (#2212)
* Nonzero should return an empty vec for zero tensors * Add nonzero empty test * Add missing import --------- Co-authored-by: Nathaniel Simard <nathaniel.simard.42@gmail.com>
This commit is contained in:
parent
96a23408d2
commit
cc214d366c
|
@ -155,6 +155,8 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
|
||||||
.tensor
|
.tensor
|
||||||
.nonzero_numpy()
|
.nonzero_numpy()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
// As opposed to tch, the resulting vector should be empty for zero tensors
|
||||||
|
.filter_map(|t| if t.numel() > 0 { Some(t) } else { None })
|
||||||
.map(TchTensor::new)
|
.map(TchTensor::new)
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
||||||
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
|
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
|
||||||
TensorData,
|
TensorData,
|
||||||
};
|
};
|
||||||
use alloc::vec::Vec;
|
use alloc::{vec, vec::Vec};
|
||||||
use core::{future::Future, ops::Range};
|
use core::{future::Future, ops::Range};
|
||||||
|
|
||||||
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
|
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
|
||||||
|
@ -426,12 +426,18 @@ pub trait BoolTensorOps<B: Backend> {
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
|
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
|
||||||
/// the non-zero elements in that dimension.
|
/// the non-zero elements in that dimension. If all elements are zero, the vector is empty.
|
||||||
fn bool_nonzero<const D: usize>(
|
fn bool_nonzero<const D: usize>(
|
||||||
tensor: BoolTensor<B, D>,
|
tensor: BoolTensor<B, D>,
|
||||||
) -> impl Future<Output = Vec<IntTensor<B, 1>>> + Send {
|
) -> impl Future<Output = Vec<IntTensor<B, 1>>> + Send {
|
||||||
async {
|
async {
|
||||||
let indices = B::bool_argwhere(tensor).await;
|
let indices = B::bool_argwhere(tensor).await;
|
||||||
|
|
||||||
|
if B::int_shape(&indices).num_elements() == 0 {
|
||||||
|
// Return empty vec when all elements are zero
|
||||||
|
return vec![];
|
||||||
|
}
|
||||||
|
|
||||||
let dims = B::int_shape(&indices).dims;
|
let dims = B::int_shape(&indices).dims;
|
||||||
B::int_chunk(indices, dims[1], 1)
|
B::int_chunk(indices, dims[1], 1)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
|
@ -91,4 +91,12 @@ mod tests {
|
||||||
actual.assert_eq(&data_expected[idx], false)
|
actual.assert_eq(&data_expected[idx], false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_nonzero_empty() {
|
||||||
|
let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);
|
||||||
|
let output = tensor.nonzero();
|
||||||
|
|
||||||
|
assert_eq!(output.len(), 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue