mirror of https://github.com/tracel-ai/burn.git
Perf/tensor ops/tests (#710)
parent f024dc9ccb
commit c89f9969ed
@@ -77,8 +77,8 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
         B::bool_equal(lhs, rhs)
     }
 
-    fn bool_equal_elem<const D: usize>(lhs: BoolTensor<B, D>, rhs: bool) -> BoolTensor<B, D> {
-        B::bool_equal_elem(lhs, rhs)
+    fn bool_not<const D: usize>(tensor: BoolTensor<B, D>) -> BoolTensor<B, D> {
+        B::bool_not(tensor)
     }
 
     fn bool_into_float<const D: usize>(
@@ -10,7 +10,6 @@ mod backend;
 mod element;
 mod ops;
 mod tensor;
 
 pub use backend::*;
 pub use tensor::*;
@@ -59,6 +58,7 @@ mod tests {
     burn_tensor::testgen_clamp!();
     burn_tensor::testgen_cos!();
     burn_tensor::testgen_div!();
+    burn_tensor::testgen_empty!();
     // burn_tensor::testgen_erf!();
     burn_tensor::testgen_exp!();
     burn_tensor::testgen_flatten!();
@@ -96,12 +96,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<CandleBackend<F,
         CandleTensor::new(lhs.tensor.eq(&rhs.tensor).unwrap())
     }
 
-    fn bool_equal_elem<const D: usize>(lhs: BoolTensor<Self, D>, rhs: bool) -> BoolTensor<Self, D> {
-        let rhs: f64 = match rhs {
-            false => 0.,
-            true => 1.,
-        };
-        let x = (candle_core::Tensor::ones_like(&lhs.tensor).unwrap() * rhs).unwrap();
-        CandleTensor::new(lhs.tensor.eq(&x).unwrap())
+    fn bool_not<const D: usize>(tensor: BoolTensor<Self, D>) -> BoolTensor<Self, D> {
+        let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap());
+        CandleTensor::new(tensor.tensor.eq(&x).unwrap())
     }
 }
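Note: the Candle backend implements logical NOT as an element-wise comparison against an all-zeros tensor, relying on the identity !x <=> (x == false), so no dedicated negation kernel is needed. A minimal standalone sketch of that identity on plain Rust bools (not Candle code; `not_via_eq` is a made-up name):

// Logical NOT expressed as equality with `false`, the same identity the
// Candle backend uses via `tensor.eq(&zeros_like)`.
fn not_via_eq(x: bool) -> bool {
    x == false // true exactly when `x` is false, i.e. `!x`
}

fn main() {
    for x in [false, true] {
        assert_eq!(not_via_eq(x), !x);
    }
}
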
@@ -106,16 +106,15 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
         rhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
     ) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
         let mut array = lhs.array;
-        array.zip_mut_with(&rhs.array, |a, b| *a = *a && *b);
+        array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b);
 
         NdArrayTensor { array }
     }
 
-    fn bool_equal_elem<const D: usize>(
-        lhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
-        rhs: bool,
+    fn bool_not<const D: usize>(
+        tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
     ) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
-        let array = lhs.array.mapv(|a| a == rhs).into_shared();
+        let array = tensor.array.mapv(|a| !a).into_shared();
         NdArrayTensor { array }
     }
 
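Note: besides replacing `bool_equal_elem` with `bool_not`, this hunk fixes `bool_equal` itself: element-wise equality of booleans is `==` (XNOR), not `&&`, which wrongly reported a pair of `false` elements as unequal. A small standalone sketch of the difference, assuming the `ndarray` crate and mirroring the `zip_mut_with` calls above:

use ndarray::array;

fn main() {
    let lhs = array![true, true, false, false];
    let rhs = array![true, false, true, false];

    // Fixed formula: element-wise equality (XNOR).
    let mut eq = lhs.clone();
    eq.zip_mut_with(&rhs, |a, b| *a = *a == *b);
    assert_eq!(eq, array![true, false, false, true]);

    // Old formula: AND, which maps (false, false) to false instead of true.
    let mut and = lhs.clone();
    and.zip_mut_with(&rhs, |a, b| *a = *a && *b);
    assert_eq!(and, array![true, false, false, false]);
}
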
@@ -97,15 +97,10 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
         TchOps::equal(lhs, rhs)
     }
 
-    fn bool_equal_elem<const D: usize>(lhs: TchTensor<bool, D>, rhs: bool) -> TchTensor<bool, D> {
-        let rhs = match rhs {
-            true => 1,
-            false => 0,
-        };
-
-        lhs.unary_ops(
-            |mut tensor| tensor.eq_(rhs).to_kind(tch::Kind::Bool),
-            |tensor| tensor.eq(rhs),
+    fn bool_not<const D: usize>(tensor: TchTensor<bool, D>) -> TchTensor<bool, D> {
+        tensor.unary_ops(
+            |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),
+            |tensor| tensor.eq(0),
         )
     }
 
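Note: LibTorch stores `false` as 0, so comparing against the scalar 0 negates a bool tensor. The call goes through the backend's `unary_ops` helper, which takes an in-place closure (`eq_`) and an allocating fallback (`eq`); presumably the in-place path is taken only when the tensor buffer is not shared. A hypothetical standalone sketch of that dispatch pattern, with an `Arc<Vec<i64>>` standing in for a tensor buffer (this is not the real helper):

use std::sync::Arc;

// Hypothetical stand-in for `unary_ops`: mutate in place when we are the sole
// owner of the buffer, otherwise fall back to the allocating variant.
fn unary_ops<F, G>(data: Arc<Vec<i64>>, in_place: F, fallback: G) -> Arc<Vec<i64>>
where
    F: FnOnce(&mut Vec<i64>),
    G: FnOnce(&[i64]) -> Vec<i64>,
{
    match Arc::try_unwrap(data) {
        Ok(mut owned) => {
            in_place(&mut owned); // like `tensor.eq_(0)`
            Arc::new(owned)
        }
        Err(shared) => Arc::new(fallback(&shared)), // like `tensor.eq(0)`
    }
}

fn main() {
    let bools = Arc::new(vec![0_i64, 1, 1, 0]);
    let negated = unary_ops(
        bools,
        |v| v.iter_mut().for_each(|x| *x = (*x == 0) as i64),
        |v| v.iter().map(|&x| (x == 0) as i64).collect(),
    );
    assert_eq!(*negated, vec![1, 0, 0, 1]);
}
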
@@ -23,4 +23,9 @@ where
     pub fn float(self) -> Tensor<B, D> {
         Tensor::new(B::bool_into_float(self.primitive))
     }
+
+    /// Inverses boolean values.
+    pub fn bool_not(self) -> Self {
+        Tensor::new(B::bool_not(self.primitive))
+    }
 }
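Note: this is the user-facing method added by the PR; the tests added further down exercise exactly this path. A usage sketch, assuming the `burn-ndarray` backend is available (the backend choice and the `B` alias are illustrative only, not part of the diff):

use burn_tensor::{Bool, Data, Tensor};

// Illustrative backend choice; any backend implementing `bool_not` works the same.
type B = burn_ndarray::NdArrayBackend<f32>;

fn main() {
    let mask = Tensor::<B, 2, Bool>::from_data([[true, false], [false, true]]);
    let inverted = mask.bool_not();
    assert_eq!(inverted.into_data(), Data::from([[false, true], [true, false]]));
}
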
@@ -227,18 +227,14 @@ pub trait BoolTensorOps<B: Backend> {
         rhs: B::BoolTensorPrimitive<D>,
     ) -> B::BoolTensorPrimitive<D>;
 
-    /// Equates the tensor with the element.
+    /// Inverses boolean values.
     ///
     /// # Arguments
     ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side element.
+    /// * `tensor` - The tensor.
     ///
     /// # Returns
     ///
-    /// The tensor with the result of the equate.
-    fn bool_equal_elem<const D: usize>(
-        lhs: B::BoolTensorPrimitive<D>,
-        rhs: bool,
-    ) -> B::BoolTensorPrimitive<D>;
+    /// The tensor with the result of the negation.
+    fn bool_not<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> B::BoolTensorPrimitive<D>;
 }
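Note: with `bool_not` now part of `BoolTensorOps`, it can also be reached from backend-generic code. A minimal sketch; the helper `invert_mask` is hypothetical, and the `burn_tensor::ops::BoolTensorOps` import path is assumed:

use burn_tensor::{backend::Backend, ops::BoolTensorOps};

// Hypothetical backend-generic helper built on the new trait method.
fn invert_mask<B: Backend, const D: usize>(
    mask: B::BoolTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D> {
    B::bool_not(mask)
}
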
@@ -39,6 +39,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_clamp!();
        burn_tensor::testgen_cos!();
        burn_tensor::testgen_div!();
+       burn_tensor::testgen_empty!();
        burn_tensor::testgen_erf!();
        burn_tensor::testgen_exp!();
        burn_tensor::testgen_flatten!();
@@ -1,7 +1,7 @@
 #[burn_tensor_testgen::testgen(cat)]
 mod tests {
     use super::*;
-    use burn_tensor::{Data, Tensor};
+    use burn_tensor::{Bool, Data, Int, Tensor};
 
     #[test]
     fn should_support_cat_ops_2d_dim0() {
@@ -14,6 +14,28 @@ mod tests {
         data_expected.assert_approx_eq(&data_actual, 3);
     }
 
+    #[test]
+    fn should_support_cat_ops_int() {
+        let tensor_1 = Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 3]]);
+        let tensor_2 = Tensor::<TestBackend, 2, Int>::from_data([[4, 5, 6]]);
+
+        let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data();
+
+        let data_expected = Data::from([[1, 2, 3], [4, 5, 6]]);
+        assert_eq!(&data_actual, &data_expected);
+    }
+
+    #[test]
+    fn should_support_cat_ops_bool() {
+        let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data([[false, true, true]]);
+        let tensor_2 = Tensor::<TestBackend, 2, Bool>::from_data([[true, true, false]]);
+
+        let data_actual = Tensor::cat(vec![tensor_1, tensor_2], 0).into_data();
+
+        let data_expected = Data::from([[false, true, true], [true, true, false]]);
+        assert_eq!(&data_actual, &data_expected);
+    }
+
     #[test]
     fn should_support_cat_ops_2d_dim1() {
         let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]);
@@ -0,0 +1,26 @@
+#[burn_tensor_testgen::testgen(empty)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Bool, Int, Tensor};
+
+    #[test]
+    fn should_support_float_empty() {
+        let shape = [2, 2];
+        let tensor = Tensor::<TestBackend, 2>::empty(shape);
+        assert_eq!(tensor.shape(), shape.into())
+    }
+
+    #[test]
+    fn should_support_int_empty() {
+        let shape = [2, 2];
+        let tensor = Tensor::<TestBackend, 2, Int>::empty(shape);
+        assert_eq!(tensor.shape(), shape.into())
+    }
+
+    #[test]
+    fn should_support_bool_empty() {
+        let shape = [2, 2];
+        let tensor = Tensor::<TestBackend, 2, Bool>::empty(shape);
+        assert_eq!(tensor.shape(), shape.into())
+    }
+}
@@ -2,7 +2,7 @@
 mod tests {
     use super::*;
     use burn_tensor::{
-        backend::Backend, BasicOps, Data, Element, Float, Int, Numeric, Tensor, TensorKind,
+        backend::Backend, BasicOps, Bool, Data, Element, Float, Int, Numeric, Tensor, TensorKind,
     };
 
     type IntElem = <TestBackend as Backend>::IntElem;
@@ -45,7 +45,7 @@ mod tests {
 
     #[test]
     fn test_int_greater_equal_elem() {
-        greater_equal_elem::<Float, FloatElem>()
+        greater_equal_elem::<Int, IntElem>()
     }
 
     #[test]
@@ -163,8 +163,8 @@ mod tests {
         K: Numeric<TestBackend, Elem = E> + BasicOps<TestBackend, Elem = E>,
         E: Element,
     {
-        let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
-        let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
+        let data_1 = Data::<f32, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]).convert();
+        let tensor_1 = Tensor::<TestBackend, 2, K>::from_data(data_1);
 
         let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0);
         let data_actual_inplace = tensor_1.greater_equal_elem(4.0);
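Note: the fixture is now declared as explicit `f32` data and passed through `convert()`, so the same generic test body works when `K` is `Int` as well as `Float` (which is also why `test_int_greater_equal_elem` above needed fixing). A small sketch of the conversion step on its own; the `i64` target is just an example element type:

use burn_tensor::Data;

fn main() {
    // f32 fixture, as written in the test ...
    let floats = Data::<f32, 1>::from([0.0, 1.0, 4.0]);
    // ... converted to another element type, the way the generic test adapts
    // the data to whatever `K` requires.
    let ints: Data<i64, 1> = floats.convert();
    assert_eq!(ints, Data::<i64, 1>::from([0, 1, 4]));
}
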
@@ -277,4 +277,32 @@ mod tests {
         assert_eq!(data_expected, data_actual_cloned.into_data());
         assert_eq!(data_expected, data_actual_inplace.into_data());
     }
+
+    #[test]
+    fn should_support_bool_equal() {
+        let data_1 = Data::from([[false, true, true], [true, false, true]]);
+        let data_2 = Data::from([[false, false, true], [false, true, true]]);
+        let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data(data_1);
+        let tensor_2 = Tensor::<TestBackend, 2, Bool>::from_data(data_2);
+
+        let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());
+        let data_actual_inplace = tensor_1.equal(tensor_2);
+
+        let data_expected = Data::from([[true, false, true], [false, false, true]]);
+        assert_eq!(data_expected, data_actual_cloned.into_data());
+        assert_eq!(data_expected, data_actual_inplace.into_data());
+    }
+
+    #[test]
+    fn should_support_bool_not() {
+        let data_1 = Data::from([[false, true, true], [true, true, false]]);
+        let tensor_1 = Tensor::<TestBackend, 2, Bool>::from_data(data_1);
+
+        let data_actual_cloned = tensor_1.clone().bool_not();
+        let data_actual_inplace = tensor_1.bool_not();
+
+        let data_expected = Data::from([[true, false, false], [false, false, true]]);
+        assert_eq!(data_expected, data_actual_cloned.into_data());
+        assert_eq!(data_expected, data_actual_inplace.into_data());
+    }
 }
@@ -9,6 +9,7 @@ mod cat;
 mod clamp;
 mod cos;
 mod div;
+mod empty;
 mod erf;
 mod exp;
 mod flatten;
@@ -1,7 +1,7 @@
 #[burn_tensor_testgen::testgen(reshape)]
 mod tests {
     use super::*;
-    use burn_tensor::{Data, Tensor};
+    use burn_tensor::{Bool, Data, Int, Tensor};
 
     #[test]
     fn should_support_reshape_1d() {
@@ -13,6 +13,26 @@ mod tests {
         assert_eq!(data_expected, data_actual);
     }
 
+    #[test]
+    fn should_support_reshape_int() {
+        let data = Data::from([0, 1, 2]);
+        let tensor = Tensor::<TestBackend, 1, Int>::from_data(data);
+
+        let data_actual = tensor.clone().reshape([1, 3]).into_data();
+        let data_expected = Data::from([[0, 1, 2]]);
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_reshape_bool() {
+        let data = Data::from([false, true, false]);
+        let tensor = Tensor::<TestBackend, 1, Bool>::from_data(data);
+
+        let data_actual = tensor.clone().reshape([1, 3]).into_data();
+        let data_expected = Data::from([[false, true, false]]);
+        assert_eq!(data_expected, data_actual);
+    }
+
     #[test]
     fn should_support_reshape_2d() {
         let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
@@ -103,14 +103,8 @@ where
         kernel::equal(lhs, rhs)
     }
 
-    fn bool_equal_elem<const D: usize>(lhs: BoolTensor<Self, D>, rhs: bool) -> BoolTensor<Self, D> {
-        kernel::equal_elem(
-            lhs,
-            match rhs {
-                true => 1,
-                false => 0,
-            },
-        )
+    fn bool_not<const D: usize>(tensor: BoolTensor<Self, D>) -> BoolTensor<Self, D> {
+        kernel::equal_elem(tensor, 0)
     }
 
     fn bool_into_float<const D: usize>(tensor: BoolTensor<Self, D>) -> FloatTensor<Self, D> {
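Note: the WGPU backend gets negation by reusing the existing `equal_elem` comparison kernel with the scalar 0, since boolean elements are stored as 0/1 on the GPU and "equals 0" is logical NOT. A plain CPU stand-in for that reuse (the function here is hypothetical, not the actual wgpu kernel):

// Hypothetical CPU stand-in for reusing a scalar-comparison kernel as NOT.
fn equal_elem(buffer: &[u32], scalar: u32) -> Vec<u32> {
    buffer.iter().map(|&x| (x == scalar) as u32).collect()
}

fn main() {
    let bools = [0_u32, 1, 1, 0]; // 0 = false, 1 = true
    assert_eq!(equal_elem(&bools, 0), vec![1, 0, 0, 1]); // element-wise NOT
}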