Fix equal/not-equal infinity numbers for burn-ndarray (#2166)

This commit is contained in:
Dilshod Tadjibaev 2024-08-15 12:33:54 -05:00 committed by GitHub
parent 31495f72c0
commit d4a1d2026d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 50 additions and 13 deletions

View File

@ -4,7 +4,7 @@ use alloc::vec::Vec;
use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
use burn_tensor::ElementConversion;
use core::ops::Range;
use ndarray::IntoDimension;
use ndarray::{IntoDimension, Zip};
// Current crate
use crate::element::{FloatNdArrayElement, QuantElement};
@ -103,10 +103,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> BoolTensorOps<Self> for NdArray<E,
lhs: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
rhs: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
) -> <NdArray<E> as Backend>::BoolTensorPrimitive<D> {
let mut array = lhs.array;
array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b);
NdArrayTensor { array }
let output = Zip::from(&lhs.array)
.and(&rhs.array)
.map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
.into_shared();
NdArrayTensor::new(output)
}
fn bool_not<const D: usize>(

View File

@ -8,6 +8,7 @@ use burn_tensor::Distribution;
use burn_tensor::ElementConversion;
use core::ops::Range;
use ndarray::IntoDimension;
use ndarray::Zip;
// Current crate
use crate::element::ExpElement;
@ -109,9 +110,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E,
lhs: NdArrayTensor<i64, D>,
rhs: NdArrayTensor<i64, D>,
) -> NdArrayTensor<bool, D> {
let tensor = Self::int_sub(lhs, rhs);
Self::int_equal_elem(tensor, 0)
let output = Zip::from(&lhs.array)
.and(&rhs.array)
.map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
.into_shared();
NdArrayTensor::new(output)
}
fn int_equal_elem<const D: usize>(

View File

@ -1,7 +1,7 @@
// Language
use alloc::vec::Vec;
use core::ops::Range;
use ndarray::IntoDimension;
use ndarray::{IntoDimension, Zip};
// Current crate
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
@ -225,10 +225,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<bool, D> {
let tensor = NdArray::<E>::float_sub(lhs, rhs);
let zero = 0.elem();
Self::float_equal_elem(tensor, zero)
let output = Zip::from(&lhs.array)
.and(&rhs.array)
.map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
.into_shared();
NdArrayTensor::new(output)
}
fn float_equal_elem<const D: usize>(

View File

@ -128,6 +128,38 @@ mod tests {
lower_equal::<Int, IntElem>()
}
#[test]
fn test_equal_inf() {
    // Element-wise equality must treat +inf == +inf and -inf == -inf as true
    // (regression test for the previous sub-then-compare implementation,
    // where inf - inf produced NaN and compared unequal to zero).
    let device = Default::default();
    let lhs = Tensor::<TestBackend, 2>::from_data(
        TensorData::from([[0.0, 1.0, 2.0], [f32::INFINITY, 4.0, f32::NEG_INFINITY]]),
        &device,
    );
    let rhs = Tensor::<TestBackend, 2>::from_data(
        TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]),
        &device,
    );

    // Exercise both the clone-based and the consuming (in-place) call paths.
    let result_cloned = lhs.clone().equal(rhs.clone());
    let result_inplace = lhs.equal(rhs);

    let expected = TensorData::from([[false, true, false], [true, false, true]]);
    assert_eq!(expected, result_cloned.into_data());
    assert_eq!(expected, result_inplace.into_data());
}
#[test]
fn test_not_equal_inf() {
    // Element-wise inequality must handle infinities correctly: a finite
    // value compared against +/-inf is not-equal, and inf vs inf of
    // opposite signs is also not-equal.
    let device = Default::default();
    let lhs = Tensor::<TestBackend, 2>::from_data(
        TensorData::from([[0.0, 1.0, 2.0], [3.0, f32::INFINITY, 5.0]]),
        &device,
    );
    let rhs = Tensor::<TestBackend, 2>::from_data(
        TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]),
        &device,
    );

    // Exercise both the clone-based and the consuming (in-place) call paths.
    let result_cloned = lhs.clone().not_equal(rhs.clone());
    let result_inplace = lhs.not_equal(rhs);

    let expected = TensorData::from([[true, false, true], [true, true, true]]);
    assert_eq!(expected, result_cloned.into_data());
    assert_eq!(expected, result_inplace.into_data());
}
fn equal<K, E>()
where
K: Numeric<TestBackend, Elem = E> + BasicOps<TestBackend, Elem = E>,