mirror of https://github.com/tracel-ai/burn.git
Rename `diagonal` to `eye` tensor op and add missing entry for diagonal to Book tensor section (#1449)
* Update tensor.md * Rename diagonal to eye * Remove extra space per PR feedback
This commit is contained in:
parent
093cbd397d
commit
9d4fbc5a35
|
@ -172,6 +172,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
|
||||||
|
|
||||||
| Burn | PyTorch Equivalent |
|
| Burn | PyTorch Equivalent |
|
||||||
| --------------------------------------------------------------- | ---------------------------------------------- |
|
| --------------------------------------------------------------- | ---------------------------------------------- |
|
||||||
|
| `Tensor::eye(size, device)` | `torch.eye(size, device=device)` |
|
||||||
| `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` |
|
| `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` |
|
||||||
| `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` |
|
| `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` |
|
||||||
| `Tensor::zeros(shape)` | `torch.zeros(shape)` |
|
| `Tensor::zeros(shape)` | `torch.zeros(shape)` |
|
||||||
|
|
|
@ -673,12 +673,12 @@ where
|
||||||
K: Numeric<B>,
|
K: Numeric<B>,
|
||||||
K::Elem: Element,
|
K::Elem: Element,
|
||||||
{
|
{
|
||||||
/// Create diagonal matrix.
|
/// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `size` - The size of the square matrix.
|
/// * `size` - The size of the square matrix.
|
||||||
pub fn diagonal(size: usize, device: &B::Device) -> Self {
|
pub fn eye(size: usize, device: &B::Device) -> Self {
|
||||||
let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze();
|
let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze();
|
||||||
let ones = K::ones([1, size].into(), device);
|
let ones = K::ones([1, size].into(), device);
|
||||||
let zeros = K::zeros([size, size].into(), device);
|
let zeros = K::zeros([size, size].into(), device);
|
||||||
|
|
|
@ -93,7 +93,7 @@ macro_rules! testgen_all {
|
||||||
// test stats
|
// test stats
|
||||||
burn_tensor::testgen_var!();
|
burn_tensor::testgen_var!();
|
||||||
burn_tensor::testgen_cov!();
|
burn_tensor::testgen_cov!();
|
||||||
burn_tensor::testgen_diagonal!();
|
burn_tensor::testgen_eye!();
|
||||||
burn_tensor::testgen_display!();
|
burn_tensor::testgen_display!();
|
||||||
|
|
||||||
// test clone invariance
|
// test clone invariance
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
#[burn_tensor_testgen::testgen(diagonal)]
|
|
||||||
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use burn_tensor::backend::Backend;
|
|
||||||
use burn_tensor::{Data, Tensor};
|
|
||||||
|
|
||||||
type FloatElem = <TestBackend as Backend>::FloatElem;
|
|
||||||
type IntElem = <TestBackend as Backend>::IntElem;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_diagonal() {
|
|
||||||
let device = <TestBackend as Backend>::Device::default();
|
|
||||||
let data = [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]];
|
|
||||||
let lhs = Tensor::<TestBackend, 2>::from_floats(data, &device);
|
|
||||||
let rhs = Tensor::<TestBackend, 2>::diagonal(3, &device);
|
|
||||||
lhs.to_data().assert_approx_eq(&rhs.to_data(), 3);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
#[burn_tensor_testgen::testgen(eye)]
|
||||||
|
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{Data, Int, Tensor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_eye_float() {
|
||||||
|
let device = Default::default();
|
||||||
|
let tensor = TestTensor::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
|
||||||
|
let rhs = Tensor::<TestBackend, 2>::eye(3, &device);
|
||||||
|
assert_eq!(tensor.to_data(), rhs.to_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_eye_int() {
|
||||||
|
let device = Default::default();
|
||||||
|
let tensor = TestTensorInt::from([[1, 0, 0], [0, 1, 0], [0, 0, 1]]);
|
||||||
|
let rhs = Tensor::<TestBackend, 2, Int>::eye(3, &device);
|
||||||
|
assert_eq!(tensor.to_data(), rhs.to_data());
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
mod cov;
|
mod cov;
|
||||||
mod diagonal;
|
|
||||||
mod display;
|
mod display;
|
||||||
|
mod eye;
|
||||||
mod var;
|
mod var;
|
||||||
|
|
Loading…
Reference in New Issue