mirror of https://github.com/tracel-ai/burn.git
parent
393d86e99d
commit
293020aae6
|
@ -15,6 +15,7 @@ pub use tensor::*;
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate alloc;
|
||||
use super::*;
|
||||
|
||||
pub type TestBackend = CandleBackend<f32, i64>;
|
||||
|
|
|
@ -14,6 +14,8 @@ pub use tensor::*;
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate alloc;
|
||||
|
||||
type TestBackend = crate::TchBackend<f32>;
|
||||
type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
|
||||
|
|
|
@ -38,6 +38,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_cat!();
|
||||
burn_tensor::testgen_clamp!();
|
||||
burn_tensor::testgen_cos!();
|
||||
burn_tensor::testgen_create_like!();
|
||||
burn_tensor::testgen_div!();
|
||||
burn_tensor::testgen_erf!();
|
||||
burn_tensor::testgen_exp!();
|
||||
|
@ -54,6 +55,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_maxmin!();
|
||||
burn_tensor::testgen_mul!();
|
||||
burn_tensor::testgen_neg!();
|
||||
burn_tensor::testgen_one_hot!();
|
||||
burn_tensor::testgen_powf!();
|
||||
burn_tensor::testgen_random!();
|
||||
burn_tensor::testgen_repeat!();
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#[burn_tensor_testgen::testgen(arange)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Int, Tensor};
|
||||
|
||||
#[test]
|
||||
|
@ -8,4 +9,13 @@ mod tests {
|
|||
let tensor = Tensor::<TestBackend, 1, Int>::arange(2..5);
|
||||
assert_eq!(tensor.into_data(), Data::from([2, 3, 4]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_arange_device() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_device(2..5, &device);
|
||||
assert_eq!(tensor.clone().into_data(), Data::from([2, 3, 4]));
|
||||
assert_eq!(tensor.device(), device);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#[burn_tensor_testgen::testgen(arange_step)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Int, Tensor};
|
||||
|
||||
#[test]
|
||||
|
@ -18,9 +19,27 @@ mod tests {
|
|||
assert_eq!(tensor.into_data(), Data::from([0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_arange_step_device() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
|
||||
// Test correct sequence of numbers when the range is 0..9 and the step is 1
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step_device(0..9, 1, &device);
|
||||
assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8]));
|
||||
|
||||
// Test correct sequence of numbers when the range is 0..3 and the step is 2
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step_device(0..3, 2, &device);
|
||||
assert_eq!(tensor.into_data(), Data::from([0, 2]));
|
||||
|
||||
// Test correct sequence of numbers when the range is 0..2 and the step is 5
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step_device(0..2, 5, &device);
|
||||
assert_eq!(tensor.clone().into_data(), Data::from([0]));
|
||||
assert_eq!(tensor.device(), device);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_arange_step_panic() {
|
||||
fn should_panic_when_step_is_zero() {
|
||||
// Test that arange_step panics when the step is 0
|
||||
let _tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 0);
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
#[burn_tensor_testgen::testgen(cat)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::{Bool, Data, Int, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_support_cat_ops_2d_dim0() {
|
||||
let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0]]);
|
||||
|
@ -57,4 +57,29 @@ mod tests {
|
|||
let data_expected = Data::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_dimensions_are_not_the_same() {
|
||||
let tensor_1 = TestTensor::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
|
||||
let tensor_2 = TestTensor::from_data([[4.0, 5.0]]);
|
||||
|
||||
TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_list_of_vectors_is_empty() {
|
||||
let tensor: Vec<TestTensor<2>> = vec![];
|
||||
TestTensor::cat(tensor, 0).into_data();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_cat_exceeds_dimension() {
|
||||
let tensor_1 = TestTensor::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]);
|
||||
let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]]);
|
||||
|
||||
TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
#[burn_tensor_testgen::testgen(create_like)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_support_zeros_like() {
|
||||
let tensor = TestTensor::from_floats([
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
]);
|
||||
|
||||
let data_actual = tensor.zeros_like().into_data();
|
||||
|
||||
let data_expected =
|
||||
Data::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]);
|
||||
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_ones_like() {
|
||||
let tensor = TestTensor::from_floats([
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
]);
|
||||
|
||||
let data_actual = tensor.ones_like().into_data();
|
||||
|
||||
let data_expected =
|
||||
Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]);
|
||||
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_randoms_like() {
|
||||
let tensor = TestTensor::from_floats([
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
]);
|
||||
|
||||
let data_actual = tensor
|
||||
.random_like(Distribution::Uniform(0.99999, 1.))
|
||||
.into_data();
|
||||
|
||||
let data_expected =
|
||||
Data::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]);
|
||||
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
}
|
|
@ -164,4 +164,14 @@ mod tests {
|
|||
Data::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn scatter_should_panic_on_mismatch_of_shapes() {
|
||||
let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]);
|
||||
let values = TestTensor::from_floats([5.0, 4.0]);
|
||||
let indices = TestTensorInt::from_ints([1, 0, 2]);
|
||||
|
||||
tensor.scatter(0, indices, values);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,4 +85,24 @@ mod tests {
|
|||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_inner_dimensions_are_not_equal() {
|
||||
let tensor_1 = TestTensor::from_floats([[3., 3.], [4., 4.], [5., 5.], [6., 6.]]);
|
||||
let tensor_2 =
|
||||
TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]);
|
||||
|
||||
let tensor_3 = tensor_1.matmul(tensor_2);
|
||||
|
||||
assert_eq!(
|
||||
tensor_3.into_data(),
|
||||
Data::from([
|
||||
[9., 18., 27., 36.],
|
||||
[12., 24., 36., 48.],
|
||||
[15., 30., 45., 60.],
|
||||
[18., 36., 54., 72.]
|
||||
])
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ mod cast;
|
|||
mod cat;
|
||||
mod clamp;
|
||||
mod cos;
|
||||
mod create_like;
|
||||
mod div;
|
||||
mod erf;
|
||||
mod exp;
|
||||
|
@ -24,6 +25,7 @@ mod matmul;
|
|||
mod maxmin;
|
||||
mod mul;
|
||||
mod neg;
|
||||
mod one_hot;
|
||||
mod powf;
|
||||
mod random;
|
||||
mod repeat;
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
#[burn_tensor_testgen::testgen(one_hot)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Int};
|
||||
|
||||
#[test]
|
||||
fn should_support_one_hot() {
|
||||
let tensor = TestTensor::<1>::one_hot(0, 5);
|
||||
assert_eq!(tensor.to_data(), Data::from([1., 0., 0., 0., 0.]));
|
||||
|
||||
let tensor = TestTensor::<1>::one_hot(1, 5);
|
||||
assert_eq!(tensor.to_data(), Data::from([0., 1., 0., 0., 0.]));
|
||||
|
||||
let tensor = TestTensor::<1>::one_hot(4, 5);
|
||||
assert_eq!(tensor.to_data(), Data::from([0., 0., 0., 0., 1.]));
|
||||
|
||||
let tensor = TestTensor::<1>::one_hot(1, 2);
|
||||
assert_eq!(tensor.to_data(), Data::from([0., 1.]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_index_exceeds_number_of_classes() {
|
||||
let tensor = TestTensor::<1>::one_hot(1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_number_of_classes_is_zero() {
|
||||
let tensor = TestTensor::<1>::one_hot(0, 0);
|
||||
}
|
||||
}
|
|
@ -116,4 +116,13 @@ mod tests {
|
|||
Data::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_select_panic_invalid_dimension() {
|
||||
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]);
|
||||
|
||||
tensor.select(10, indices);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -102,4 +102,48 @@ mod tests {
|
|||
let data_expected = Data::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_slice_exceeds_dimension() {
|
||||
let data = Data::from([0.0, 1.0, 2.0]);
|
||||
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
|
||||
|
||||
let data_actual = tensor.slice([0..4]).into_data();
|
||||
|
||||
assert_eq!(data, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_slice_with_too_many_dimensions() {
|
||||
let data = Data::from([0.0, 1.0, 2.0]);
|
||||
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
|
||||
|
||||
let data_actual = tensor.slice([0..1, 0..1]).into_data();
|
||||
|
||||
assert_eq!(data, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_slice_is_desc() {
|
||||
let data = Data::from([0.0, 1.0, 2.0]);
|
||||
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
|
||||
|
||||
let data_actual = tensor.slice([2..1]).into_data();
|
||||
|
||||
assert_eq!(data, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_panic_when_slice_is_equal() {
|
||||
let data = Data::from([0.0, 1.0, 2.0]);
|
||||
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
|
||||
|
||||
let data_actual = tensor.slice([1..1]).into_data();
|
||||
|
||||
assert_eq!(data, data_actual);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,12 +9,47 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_var() {
|
||||
let data = Data::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
||||
|
||||
let data_actual = tensor.var(1).into_data();
|
||||
|
||||
let data_expected = Data::from([[2.4892], [15.3333]]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_var_mean() {
|
||||
let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
||||
|
||||
let (var, mean) = tensor.var_mean(1);
|
||||
|
||||
let var_expected = Data::from([[2.4892], [15.3333]]);
|
||||
let mean_expected = Data::from([[0.125], [1.]]);
|
||||
|
||||
var_expected.assert_approx_eq(&(var.into_data()), 3);
|
||||
mean_expected.assert_approx_eq(&(mean.into_data()), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_var_bias() {
|
||||
let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
||||
|
||||
let data_actual = tensor.var_bias(1).into_data();
|
||||
|
||||
let data_expected = Data::from([[1.86688], [11.5]]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_var_mean_bias() {
|
||||
let tensor = TestTensor::from_data([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);
|
||||
|
||||
let (var, mean) = tensor.var_mean_bias(1);
|
||||
|
||||
let var_expected = Data::from([[1.86688], [11.5]]);
|
||||
let mean_expected = Data::from([[0.125], [1.]]);
|
||||
|
||||
var_expected.assert_approx_eq(&(var.into_data()), 3);
|
||||
mean_expected.assert_approx_eq(&(mean.into_data()), 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
extern crate alloc;
|
||||
|
||||
mod ops;
|
||||
|
||||
|
|
Loading…
Reference in New Issue