Rename `diagonal` to `eye` tensor op and add missing entry for diagonal to Book tensor section (#1449)

* Update tensor.md

* Rename diagonal to eye

* Remove extra space per PR feedback
This commit is contained in:
Dilshod Tadjibaev 2024-03-11 11:00:36 -05:00 committed by GitHub
parent 093cbd397d
commit 9d4fbc5a35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 26 additions and 23 deletions

View File

@ -172,6 +172,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| Burn | PyTorch Equivalent | | Burn | PyTorch Equivalent |
| --------------------------------------------------------------- | ---------------------------------------------- | | --------------------------------------------------------------- | ---------------------------------------------- |
| `Tensor::eye(size, device)` | `torch.eye(size, device=device)` |
| `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` | | `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` |
| `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` | | `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` |
| `Tensor::zeros(shape)` | `torch.zeros(shape)` | | `Tensor::zeros(shape)` | `torch.zeros(shape)` |

View File

@ -673,12 +673,12 @@ where
K: Numeric<B>, K: Numeric<B>,
K::Elem: Element, K::Elem: Element,
{ {
/// Create diagonal matrix. /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `size` - The size of the square matrix. /// * `size` - The size of the square matrix.
pub fn diagonal(size: usize, device: &B::Device) -> Self { pub fn eye(size: usize, device: &B::Device) -> Self {
let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze(); let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze();
let ones = K::ones([1, size].into(), device); let ones = K::ones([1, size].into(), device);
let zeros = K::zeros([size, size].into(), device); let zeros = K::zeros([size, size].into(), device);

View File

@ -93,7 +93,7 @@ macro_rules! testgen_all {
// test stats // test stats
burn_tensor::testgen_var!(); burn_tensor::testgen_var!();
burn_tensor::testgen_cov!(); burn_tensor::testgen_cov!();
burn_tensor::testgen_diagonal!(); burn_tensor::testgen_eye!();
burn_tensor::testgen_display!(); burn_tensor::testgen_display!();
// test clone invariance // test clone invariance

View File

@ -1,19 +0,0 @@
#[burn_tensor_testgen::testgen(diagonal)]
mod tests {
use super::*;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Tensor};
type FloatElem = <TestBackend as Backend>::FloatElem;
type IntElem = <TestBackend as Backend>::IntElem;
#[test]
fn test_diagonal() {
let device = <TestBackend as Backend>::Device::default();
let data = [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]];
let lhs = Tensor::<TestBackend, 2>::from_floats(data, &device);
let rhs = Tensor::<TestBackend, 2>::diagonal(3, &device);
lhs.to_data().assert_approx_eq(&rhs.to_data(), 3);
}
}

View File

@ -0,0 +1,21 @@
#[burn_tensor_testgen::testgen(eye)]
mod tests {
use super::*;
use burn_tensor::{Data, Int, Tensor};
#[test]
fn test_eye_float() {
let device = Default::default();
let tensor = TestTensor::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
let rhs = Tensor::<TestBackend, 2>::eye(3, &device);
assert_eq!(tensor.to_data(), rhs.to_data());
}
fn test_eye_int() {
let device = Default::default();
let tensor = TestTensorInt::from([[1, 0, 0], [0, 1, 0], [0, 0, 1]]);
let rhs = Tensor::<TestBackend, 2, Int>::eye(3, &device);
assert_eq!(tensor.to_data(), rhs.to_data());
}
}

View File

@ -1,4 +1,4 @@
mod cov; mod cov;
mod diagonal;
mod display; mod display;
mod eye;
mod var; mod var;