Add is_nan and contains_nan tensor ops (#2088)

* Add is_nan and contains_nan tensor ops

* Enable nan test for burn-candle

* Disabling tests due to #2089
This commit is contained in:
Dilshod Tadjibaev 2024-08-06 12:16:12 -05:00 committed by GitHub
parent 27d42cdaad
commit cd848b1c94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 68 additions and 5 deletions

View File

@ -190,6 +190,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
| `tensor.contains_nan()` | N/A |
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
@ -199,6 +200,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.greater_equal(other)` | `tensor.ge(other)` |
| `tensor.greater_equal_elem(scalar)` | `tensor.ge(scalar)` |
| `tensor.is_close(other, atol, rtol)` | `torch.isclose(tensor, other, atol, rtol)` |
| `tensor.is_nan()` | `torch.isnan(tensor)` |
| `tensor.lower(other)` | `tensor.lt(other)` |
| `tensor.lower_elem(scalar)` | `tensor.lt(scalar)` |
| `tensor.lower_equal(other)` | `tensor.le(other)` |
@ -304,10 +306,11 @@ Those operations are only available for `Bool` tensors.
### Quantization Operations
Those operations are only available for `Float` tensors on backends that implement quantization strategies.
Those operations are only available for `Float` tensors on backends that implement quantization
strategies.
| Burn API | PyTorch Equivalent |
| ------------------------------------ | ------------------------------- |
| ---------------------------------- | ------------------ |
| `tensor.quantize(scheme, qparams)` | N/A |
| `tensor.dequantize()` | N/A |

View File

@ -87,6 +87,7 @@ mod tests {
burn_tensor::testgen_flip!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_nan!();
// TODO: https://github.com/tracel-ai/burn/issues/1237
//

View File

@ -778,6 +778,32 @@ where
// Assign the original tensor data to the appropriate slice of the padded tensor
padded_tensor.slice_assign(ranges, self)
}
/// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
///
/// # Returns
///
/// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
pub fn is_nan(&self) -> Tensor<B, D, Bool> {
// Check if the input tensor is NaN by comparing it to itself
// NaN is the only value that is not equal to itself
K::not_equal(self.primitive.clone(), self.primitive.clone())
}
/// Checks if the tensor contains any NaN values.
///
/// # Returns
///
/// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
// Summing the tensor will result in NaN if the tensor contains any NaN values
// This is faster than checking each element individually
// because it rolls up the NaN values into a single value
let sum = K::sum(self.primitive.clone());
// Check if the sum is NaN by comparing it to itself
K::not_equal(sum.clone(), sum)
}
}
impl<B, K> Tensor<B, 2, K>

View File

@ -103,6 +103,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_topk!();
burn_tensor::testgen_remainder!();
burn_tensor::testgen_cartesian_grid!();
burn_tensor::testgen_nan!();
// test stats
burn_tensor::testgen_var!();

View File

@ -34,6 +34,7 @@ mod matmul;
mod maxmin;
mod movedim;
mod mul;
mod nan;
mod narrow;
mod neg;
mod one_hot;

View File

@ -0,0 +1,31 @@
#[burn_tensor_testgen::testgen(nan)]
mod tests {
use super::*;
use burn_tensor::{Int, Tensor, TensorData};
#[test]
#[ignore = "https://github.com/tracel-ai/burn/issues/2089"]
fn is_nan() {
let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let no_nan_expected =
TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);
let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [f32::NAN, 4.0, 5.0]]);
let with_nan_expected =
TestTensorBool::<2>::from([[false, true, false], [true, false, false]]);
assert_eq!(no_nan_expected.into_data(), no_nan.is_nan().into_data());
assert_eq!(with_nan_expected.into_data(), with_nan.is_nan().into_data());
}
#[test]
#[ignore = "https://github.com/tracel-ai/burn/issues/2089"]
fn contains_nan() {
let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
assert!(!no_nan.contains_nan().into_scalar());
let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [3.0, 4.0, 5.0]]);
assert!(with_nan.contains_nan().into_scalar());
}
}