Perf/tensor ops/tests (#710)

This commit is contained in:
Louis Fortier-Dubois 2023-08-28 12:53:17 -04:00 committed by GitHub
parent f024dc9ccb
commit c89f9969ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 129 additions and 46 deletions

View File

@ -77,8 +77,8 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
B::bool_equal(lhs, rhs)
}
fn bool_equal_elem<const D: usize>(lhs: BoolTensor<B, D>, rhs: bool) -> BoolTensor<B, D> {
B::bool_equal_elem(lhs, rhs)
fn bool_not<const D: usize>(tensor: BoolTensor<B, D>) -> BoolTensor<B, D> {
B::bool_not(tensor)
}
fn bool_into_float<const D: usize>(

View File

@ -10,7 +10,6 @@ mod backend;
mod element;
mod ops;
mod tensor;
pub use backend::*;
pub use tensor::*;
@ -59,6 +58,7 @@ mod tests {
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_div!();
burn_tensor::testgen_empty!();
// burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_flatten!();

View File

@ -96,12 +96,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<CandleBackend<F,
CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap())
}
fn bool_equal_elem<const D: usize>(lhs: BoolTensor<Self, D>, rhs: bool) -> BoolTensor<Self, D> {
let rhs: f64 = match rhs {
false => 0.,
true => 1.,
};
let x = (candle_core::Tensor::ones_like(&lhs.tensor).unwrap() * rhs).unwrap();
CandleTensor::new(lhs.tensor.eq(&x).unwrap())
fn bool_not<const D: usize>(tensor: BoolTensor<Self, D>) -> BoolTensor<Self, D> {
let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap());
CandleTensor::new(tensor.tensor.eq(&x).unwrap())
}
}

View File

@ -106,16 +106,15 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
rhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let mut array = lhs.array;
array.zip_mut_with(&rhs.array, |a, b| *a = *a && *b);
array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b);
NdArrayTensor { array }
}
fn bool_equal_elem<const D: usize>(
lhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
rhs: bool,
fn bool_not<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let array = lhs.array.mapv(|a| a == rhs).into_shared();
let array = tensor.array.mapv(|a| !a).into_shared();
NdArrayTensor { array }
}

View File

@ -97,15 +97,10 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
TchOps::equal(lhs, rhs)
}
fn bool_equal_elem<const D: usize>(lhs: TchTensor<bool, D>, rhs: bool) -> TchTensor<bool, D> {
let rhs = match rhs {
true => 1,
false => 0,
};
lhs.unary_ops(
|mut tensor| tensor.eq_(rhs).to_kind(tch::Kind::Bool),
|tensor| tensor.eq(rhs),
fn bool_not<const D: usize>(tensor: TchTensor<bool, D>) -> TchTensor<bool, D> {
tensor.unary_ops(
|mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
|tensor| tensor.eq(0),
)
}

View File

@ -23,4 +23,9 @@ where
pub fn float(self) -> Tensor<B, D> {
Tensor::new(B::bool_into_float(self.primitive))
}
/// Inverses boolean values.
pub fn bool_not(self) -> Self {
Tensor::new(B::bool_not(self.primitive))
}
}

View File

@ -227,18 +227,14 @@ pub trait BoolTensorOps<B: Backend> {
rhs: B::BoolTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Equates the tensor with the element.
/// Inverses boolean values.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side element.
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The tensor with the result of the equate.
fn bool_equal_elem<const D: usize>(
lhs: B::BoolTensorPrimitive<D>,
rhs: bool,
) -> B::BoolTensorPrimitive<D>;
/// The tensor with the result of the negation.
fn bool_not<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> B::BoolTensorPrimitive<D>;
}

View File

@ -39,6 +39,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_div!();
burn_tensor::testgen_empty!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_flatten!();

View File

@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(cat)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
use burn_tensor::{Bool, Data, Int, Tensor};
#[test]
fn should_support_cat_ops_2d_dim0() {
@ -14,6 +14,28 @@ mod tests {
data_expected.assert_approx_eq(&data_actual, 3);
}
#[test]
fn should_support_cat_ops_int() {
let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 3]]);
let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data([[4, 5, 6]]);
let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data();
let data_expected = Data::from([[1, 2, 3], [4, 5, 6]]);
assert_eq!(&data_actual, &data_expected);
}
#[test]
fn should_support_cat_ops_bool() {
let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data([[false, true, true]]);
let tensor_2 = Tensor::<TestBackend, 2, Bool>::from_data([[true, true, false]]);
let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data();
let data_expected = Data::from([[false, true, true], [true, true, false]]);
assert_eq!(&data_actual, &data_expected);
}
#[test]
fn should_support_cat_ops_2d_dim1() {
let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]);

View File

@ -0,0 +1,26 @@
#[burn_tensor_testgen::testgen(empty)]
mod tests {
use super::*;
use burn_tensor::{Bool, Int, Tensor};
#[test]
fn should_support_float_empty() {
let shape = [2, 2];
let tensor = Tensor::<TestBackend, 2>::empty(shape);
assert_eq!(tensor.shape(), shape.into())
}
#[test]
fn should_support_int_empty() {
let shape = [2, 2];
let tensor = Tensor::<TestBackend, 2, Int>::empty(shape);
assert_eq!(tensor.shape(), shape.into())
}
#[test]
fn should_support_bool_empty() {
let shape = [2, 2];
let tensor = Tensor::<TestBackend, 2, Bool>::empty(shape);
assert_eq!(tensor.shape(), shape.into())
}
}

View File

@ -2,7 +2,7 @@
mod tests {
use super::*;
use burn_tensor::{
backend::Backend, BasicOps, Data, Element, Float, Int, Numeric, Tensor, TensorKind,
backend::Backend, BasicOps, Bool, Data, Element, Float, Int, Numeric, Tensor, TensorKind,
};
type IntElem = <TestBackend as Backend>::IntElem;
@ -45,7 +45,7 @@ mod tests {
#[test]
fn test_int_greater_equal_elem() {
greater_equal_elem::<Float, FloatElem>()
greater_equal_elem::<Int, IntElem>()
}
#[test]
@ -163,8 +163,8 @@ mod tests {
K: Numeric<TestBackend, Elem = E> + BasicOps<TestBackend, Elem = E>,
E: Element,
{
let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
let data_1 = Data::<f32, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert();
let tensor_1 = Tensor::<TestBackend, 2, K>::from_data(data_1);
let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0);
let data_actual_inplace = tensor_1.greater_equal_elem(4.0);
@ -277,4 +277,32 @@ mod tests {
assert_eq!(data_expected, data_actual_cloned.into_data());
assert_eq!(data_expected, data_actual_inplace.into_data());
}
#[test]
fn should_support_bool_equal() {
let data_1 = Data::from([[false, true, true], [true, false, true]]);
let data_2 = Data::from([[false, false, true], [false, true, true]]);
let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data(data_1);
let tensor_2 = Tensor::<TestBackend, 2, Bool>::from_data(data_2);
let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());
let data_actual_inplace = tensor_1.equal(tensor_2);
let data_expected = Data::from([[true, false, true], [false, false, true]]);
assert_eq!(data_expected, data_actual_cloned.into_data());
assert_eq!(data_expected, data_actual_inplace.into_data());
}
#[test]
fn should_support_bool_not() {
let data_1 = Data::from([[false, true, true], [true, true, false]]);
let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data(data_1);
let data_actual_cloned = tensor_1.clone().bool_not();
let data_actual_inplace = tensor_1.bool_not();
let data_expected = Data::from([[true, false, false], [false, false, true]]);
assert_eq!(data_expected, data_actual_cloned.into_data());
assert_eq!(data_expected, data_actual_inplace.into_data());
}
}

View File

@ -9,6 +9,7 @@ mod cat;
mod clamp;
mod cos;
mod div;
mod empty;
mod erf;
mod exp;
mod flatten;

View File

@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(reshape)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
use burn_tensor::{Bool, Data, Int, Tensor};
#[test]
fn should_support_reshape_1d() {
@ -13,6 +13,26 @@ mod tests {
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_reshape_int() {
let data = Data::from([0, 1, 2]);
let tensor = Tensor::<TestBackend, 1, Int>::from_data(data);
let data_actual = tensor.clone().reshape([1, 3]).into_data();
let data_expected = Data::from([[0, 1, 2]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_reshape_bool() {
let data = Data::from([false, true, false]);
let tensor = Tensor::<TestBackend, 1, Bool>::from_data(data);
let data_actual = tensor.clone().reshape([1, 3]).into_data();
let data_expected = Data::from([[false, true, false]]);
assert_eq!(data_expected, data_actual);
}
#[test]
fn should_support_reshape_2d() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

View File

@ -103,14 +103,8 @@ where
kernel::equal(lhs, rhs)
}
fn bool_equal_elem<const D: usize>(lhs: BoolTensor<Self, D>, rhs: bool) -> BoolTensor<Self, D> {
kernel::equal_elem(
lhs,
match rhs {
true => 1,
false => 0,
},
)
fn bool_not<const D: usize>(tensor: BoolTensor<Self, D>) -> BoolTensor<Self, D> {
kernel::equal_elem(tensor, 0)
}
fn bool_into_float<const D: usize>(tensor: BoolTensor<Self, D>) -> FloatTensor<Self, D> {