mirror of https://github.com/tracel-ai/burn.git
Rename `diagonal` to `eye` tensor op and add missing entry for diagonal to Book tensor section (#1449)
* Update tensor.md * Rename diagonal to eye * Remove extra space per PR feedback
This commit is contained in:
parent
093cbd397d
commit
9d4fbc5a35
|
@ -172,6 +172,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
|
|||
|
||||
| Burn | PyTorch Equivalent |
|
||||
| --------------------------------------------------------------- | ---------------------------------------------- |
|
||||
| `Tensor::eye(size, device)` | `torch.eye(size, device=device)` |
|
||||
| `Tensor::full(shape, fill_value, device)` | `torch.full(shape, fill_value, device=device)` |
|
||||
| `Tensor::ones(shape, device)` | `torch.ones(shape, device=device)` |
|
||||
| `Tensor::zeros(shape)` | `torch.zeros(shape)` |
|
||||
|
|
|
@ -673,12 +673,12 @@ where
|
|||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
{
|
||||
/// Create diagonal matrix.
|
||||
/// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `size` - The size of the square matrix.
|
||||
pub fn diagonal(size: usize, device: &B::Device) -> Self {
|
||||
pub fn eye(size: usize, device: &B::Device) -> Self {
|
||||
let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze();
|
||||
let ones = K::ones([1, size].into(), device);
|
||||
let zeros = K::zeros([size, size].into(), device);
|
||||
|
|
|
@ -93,7 +93,7 @@ macro_rules! testgen_all {
|
|||
// test stats
|
||||
burn_tensor::testgen_var!();
|
||||
burn_tensor::testgen_cov!();
|
||||
burn_tensor::testgen_diagonal!();
|
||||
burn_tensor::testgen_eye!();
|
||||
burn_tensor::testgen_display!();
|
||||
|
||||
// test clone invariance
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
#[burn_tensor_testgen::testgen(diagonal)]
|
||||
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
type FloatElem = <TestBackend as Backend>::FloatElem;
|
||||
type IntElem = <TestBackend as Backend>::IntElem;
|
||||
|
||||
#[test]
|
||||
fn test_diagonal() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let data = [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]];
|
||||
let lhs = Tensor::<TestBackend, 2>::from_floats(data, &device);
|
||||
let rhs = Tensor::<TestBackend, 2>::diagonal(3, &device);
|
||||
lhs.to_data().assert_approx_eq(&rhs.to_data(), 3);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
#[burn_tensor_testgen::testgen(eye)]
|
||||
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Int, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_eye_float() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensor::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
|
||||
let rhs = Tensor::<TestBackend, 2>::eye(3, &device);
|
||||
assert_eq!(tensor.to_data(), rhs.to_data());
|
||||
}
|
||||
|
||||
fn test_eye_int() {
|
||||
let device = Default::default();
|
||||
let tensor = TestTensorInt::from([[1, 0, 0], [0, 1, 0], [0, 0, 1]]);
|
||||
let rhs = Tensor::<TestBackend, 2, Int>::eye(3, &device);
|
||||
assert_eq!(tensor.to_data(), rhs.to_data());
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
mod cov;
|
||||
mod diagonal;
|
||||
mod display;
|
||||
mod eye;
|
||||
mod var;
|
||||
|
|
Loading…
Reference in New Issue