mirror of https://github.com/tracel-ai/burn.git
Add is_nan and contains_nan tensor ops (#2088)
* Add is_nan and contains_nan tensor ops * Enable nan test for burn-candle * Disabling tests due to #2089
This commit is contained in:
parent
27d42cdaad
commit
cd848b1c94
|
@ -190,6 +190,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
|
|||
| `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
|
||||
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
|
||||
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
|
||||
| `tensor.contains_nan()` | N/A |
|
||||
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
|
||||
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
|
||||
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
|
||||
|
@ -199,6 +200,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
|
|||
| `tensor.greater_equal(other)` | `tensor.ge(other)` |
|
||||
| `tensor.greater_equal_elem(scalar)` | `tensor.ge(scalar)` |
|
||||
| `tensor.is_close(other, atol, rtol)` | `torch.isclose(tensor, other, atol, rtol)` |
|
||||
| `tensor.is_nan()` | `torch.isnan(tensor)` |
|
||||
| `tensor.lower(other)` | `tensor.lt(other)` |
|
||||
| `tensor.lower_elem(scalar)` | `tensor.lt(scalar)` |
|
||||
| `tensor.lower_equal(other)` | `tensor.le(other)` |
|
||||
|
@ -304,10 +306,11 @@ Those operations are only available for `Bool` tensors.
|
|||
|
||||
### Quantization Operations
|
||||
|
||||
Those operations are only available for `Float` tensors on backends that implement quantization strategies.
|
||||
Those operations are only available for `Float` tensors on backends that implement quantization
|
||||
strategies.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| ------------------------------------ | ------------------------------- |
|
||||
| ---------------------------------- | ------------------ |
|
||||
| `tensor.quantize(scheme, qparams)` | N/A |
|
||||
| `tensor.dequantize()` | N/A |
|
||||
|
||||
|
|
|
@ -87,6 +87,7 @@ mod tests {
|
|||
burn_tensor::testgen_flip!();
|
||||
burn_tensor::testgen_argwhere_nonzero!();
|
||||
burn_tensor::testgen_sign!();
|
||||
burn_tensor::testgen_nan!();
|
||||
|
||||
// TODO: https://github.com/tracel-ai/burn/issues/1237
|
||||
//
|
||||
|
|
|
@ -778,6 +778,32 @@ where
|
|||
// Assign the original tensor data to the appropriate slice of the padded tensor
|
||||
padded_tensor.slice_assign(ranges, self)
|
||||
}
|
||||
|
||||
/// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
|
||||
pub fn is_nan(&self) -> Tensor<B, D, Bool> {
|
||||
// Check if the input tensor is NaN by comparing it to itself
|
||||
// NaN is the only value that is not equal to itself
|
||||
K::not_equal(self.primitive.clone(), self.primitive.clone())
|
||||
}
|
||||
|
||||
/// Checks if the tensor contains any NaN values.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
|
||||
pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
|
||||
// Summing the tensor will result in NaN if the tensor contains any NaN values
|
||||
// This is faster than checking each element individually
|
||||
// because it rolls up the NaN values into a single value
|
||||
let sum = K::sum(self.primitive.clone());
|
||||
|
||||
// Check if the sum is NaN by comparing it to itself
|
||||
K::not_equal(sum.clone(), sum)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, K> Tensor<B, 2, K>
|
||||
|
|
|
@ -103,6 +103,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_topk!();
|
||||
burn_tensor::testgen_remainder!();
|
||||
burn_tensor::testgen_cartesian_grid!();
|
||||
burn_tensor::testgen_nan!();
|
||||
|
||||
// test stats
|
||||
burn_tensor::testgen_var!();
|
||||
|
|
|
@ -34,6 +34,7 @@ mod matmul;
|
|||
mod maxmin;
|
||||
mod movedim;
|
||||
mod mul;
|
||||
mod nan;
|
||||
mod narrow;
|
||||
mod neg;
|
||||
mod one_hot;
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
#[burn_tensor_testgen::testgen(nan)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Int, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
#[ignore = "https://github.com/tracel-ai/burn/issues/2089"]
|
||||
fn is_nan() {
|
||||
let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let no_nan_expected =
|
||||
TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);
|
||||
|
||||
let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [f32::NAN, 4.0, 5.0]]);
|
||||
let with_nan_expected =
|
||||
TestTensorBool::<2>::from([[false, true, false], [true, false, false]]);
|
||||
|
||||
assert_eq!(no_nan_expected.into_data(), no_nan.is_nan().into_data());
|
||||
|
||||
assert_eq!(with_nan_expected.into_data(), with_nan.is_nan().into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "https://github.com/tracel-ai/burn/issues/2089"]
|
||||
fn contains_nan() {
|
||||
let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
assert!(!no_nan.contains_nan().into_scalar());
|
||||
|
||||
let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [3.0, 4.0, 5.0]]);
|
||||
assert!(with_nan.contains_nan().into_scalar());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue