diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 88b81162c..26ac84ed1 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -172,6 +172,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | Burn | PyTorch Equivalent | | --------------------------------------------------------------- | ---------------------------------------------- | +| `Tensor::eye(size, device)` | `torch.eye(size, device=device)` | | `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` | | `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` | | `Tensor::zeros(shape)` | `torch.zeros(shape)` | diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index faecce63a..8964617f9 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -673,12 +673,12 @@ where K: Numeric, K::Elem: Element, { - /// Create diagonal matrix. + /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere. /// /// # Arguments /// /// * `size` - The size of the square matrix. - pub fn diagonal(size: usize, device: &B::Device) -> Self { + pub fn eye(size: usize, device: &B::Device) -> Self { let indices = Tensor::::arange(0..size as i64, device).unsqueeze(); let ones = K::ones([1, size].into(), device); let zeros = K::zeros([size, size].into(), device); diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index ba9c677de..619c21258 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -93,7 +93,7 @@ macro_rules! testgen_all { // test stats burn_tensor::testgen_var!(); burn_tensor::testgen_cov!(); - burn_tensor::testgen_diagonal!(); + burn_tensor::testgen_eye!(); burn_tensor::testgen_display!(); // test clone invariance diff --git a/crates/burn-tensor/src/tests/stats/diagonal.rs b/crates/burn-tensor/src/tests/stats/diagonal.rs deleted file mode 100644 index f60bce494..000000000 --- a/crates/burn-tensor/src/tests/stats/diagonal.rs +++ /dev/null @@ -1,19 +0,0 @@ -#[burn_tensor_testgen::testgen(diagonal)] - -mod tests { - use super::*; - use burn_tensor::backend::Backend; - use burn_tensor::{Data, Tensor}; - - type FloatElem = ::FloatElem; - type IntElem = ::IntElem; - - #[test] - fn test_diagonal() { - let device = ::Device::default(); - let data = [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]; - let lhs = Tensor::::from_floats(data, &device); - let rhs = Tensor::::diagonal(3, &device); - lhs.to_data().assert_approx_eq(&rhs.to_data(), 3); - } -} diff --git a/crates/burn-tensor/src/tests/stats/eye.rs b/crates/burn-tensor/src/tests/stats/eye.rs new file mode 100644 index 000000000..578664418 --- /dev/null +++ b/crates/burn-tensor/src/tests/stats/eye.rs @@ -0,0 +1,21 @@ +#[burn_tensor_testgen::testgen(eye)] + +mod tests { + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_eye_float() { + let device = Default::default(); + let tensor = TestTensor::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); + let rhs = Tensor::::eye(3, &device); + assert_eq!(tensor.to_data(), rhs.to_data()); + } + + fn test_eye_int() { + let device = Default::default(); + let tensor = TestTensorInt::from([[1, 0, 0], [0, 1, 0], [0, 0, 1]]); + let rhs = Tensor::::eye(3, &device); + assert_eq!(tensor.to_data(), rhs.to_data()); + } +} diff --git a/crates/burn-tensor/src/tests/stats/mod.rs b/crates/burn-tensor/src/tests/stats/mod.rs index 177bd46c5..79c4cdca4 100644 --- a/crates/burn-tensor/src/tests/stats/mod.rs +++ b/crates/burn-tensor/src/tests/stats/mod.rs @@ -1,4 +1,4 @@ mod cov; -mod diagonal; mod display; +mod eye; mod var;