mirror of https://github.com/tracel-ai/burn.git
Fix equal/not-equal infinity numbers for burn-ndarray (#2166)
This commit is contained in:
parent
31495f72c0
commit
d4a1d2026d
|
@ -4,7 +4,7 @@ use alloc::vec::Vec;
|
|||
use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
|
||||
use burn_tensor::ElementConversion;
|
||||
use core::ops::Range;
|
||||
use ndarray::IntoDimension;
|
||||
use ndarray::{IntoDimension, Zip};
|
||||
|
||||
// Current crate
|
||||
use crate::element::{FloatNdArrayElement, QuantElement};
|
||||
|
@ -103,10 +103,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> BoolTensorOps<Self> for NdArray<E,
|
|||
lhs: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
rhs: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <NdArray<E> as Backend>::BoolTensorPrimitive<D> {
|
||||
let mut array = lhs.array;
|
||||
array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b);
|
||||
|
||||
NdArrayTensor { array }
|
||||
let output = Zip::from(&lhs.array)
|
||||
.and(&rhs.array)
|
||||
.map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
|
||||
.into_shared();
|
||||
NdArrayTensor::new(output)
|
||||
}
|
||||
|
||||
fn bool_not<const D: usize>(
|
||||
|
|
|
@ -8,6 +8,7 @@ use burn_tensor::Distribution;
|
|||
use burn_tensor::ElementConversion;
|
||||
use core::ops::Range;
|
||||
use ndarray::IntoDimension;
|
||||
use ndarray::Zip;
|
||||
|
||||
// Current crate
|
||||
use crate::element::ExpElement;
|
||||
|
@ -109,9 +110,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E,
|
|||
lhs: NdArrayTensor<i64, D>,
|
||||
rhs: NdArrayTensor<i64, D>,
|
||||
) -> NdArrayTensor<bool, D> {
|
||||
let tensor = Self::int_sub(lhs, rhs);
|
||||
|
||||
Self::int_equal_elem(tensor, 0)
|
||||
let output = Zip::from(&lhs.array)
|
||||
.and(&rhs.array)
|
||||
.map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
|
||||
.into_shared();
|
||||
NdArrayTensor::new(output)
|
||||
}
|
||||
|
||||
fn int_equal_elem<const D: usize>(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Language
|
||||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
use ndarray::IntoDimension;
|
||||
use ndarray::{IntoDimension, Zip};
|
||||
|
||||
// Current crate
|
||||
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
|
||||
|
@ -225,10 +225,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
|
|||
lhs: NdArrayTensor<E, D>,
|
||||
rhs: NdArrayTensor<E, D>,
|
||||
) -> NdArrayTensor<bool, D> {
|
||||
let tensor = NdArray::<E>::float_sub(lhs, rhs);
|
||||
let zero = 0.elem();
|
||||
|
||||
Self::float_equal_elem(tensor, zero)
|
||||
let output = Zip::from(&lhs.array)
|
||||
.and(&rhs.array)
|
||||
.map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
|
||||
.into_shared();
|
||||
NdArrayTensor::new(output)
|
||||
}
|
||||
|
||||
fn float_equal_elem<const D: usize>(
|
||||
|
|
|
@ -128,6 +128,38 @@ mod tests {
|
|||
lower_equal::<Int, IntElem>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_equal_inf() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0, 2.0], [f32::INFINITY, 4.0, f32::NEG_INFINITY]]);
|
||||
let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1, &device);
|
||||
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2, &device);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[false, true, false], [true, false, true]]);
|
||||
assert_eq!(data_expected, data_actual_cloned.into_data());
|
||||
assert_eq!(data_expected, data_actual_inplace.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_equal_inf() {
|
||||
let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, f32::INFINITY, 5.0]]);
|
||||
let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]);
|
||||
let device = Default::default();
|
||||
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1, &device);
|
||||
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2, &device);
|
||||
|
||||
let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());
|
||||
let data_actual_inplace = tensor_1.not_equal(tensor_2);
|
||||
|
||||
let data_expected = TensorData::from([[true, false, true], [true, true, true]]);
|
||||
assert_eq!(data_expected, data_actual_cloned.into_data());
|
||||
assert_eq!(data_expected, data_actual_inplace.into_data());
|
||||
}
|
||||
|
||||
fn equal<K, E>()
|
||||
where
|
||||
K: Numeric<TestBackend, Elem = E> + BasicOps<TestBackend, Elem = E>,
|
||||
|
|
Loading…
Reference in New Issue