mirror of https://github.com/tracel-ai/burn.git
Pretty Print Tensors (#257)
This commit is contained in:
parent
ca8ee0724d
commit
d8f64ce1dd
|
@ -28,7 +28,7 @@ mod tests {
|
|||
type TestBackend = crate::NdArrayBackend<f32>;
|
||||
type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
|
||||
|
||||
use alloc::format;
|
||||
burn_tensor::testgen_all!();
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
use core::{fmt::Debug, ops::Range};
|
||||
|
||||
use crate::{backend::Backend, Bool, Data, Float, Int, Shape, TensorKind};
|
||||
|
||||
|
@ -266,13 +269,93 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<B, const D: usize, K> Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
K: BasicOps<B>,
|
||||
<K as BasicOps<B>>::Elem: Debug,
|
||||
{
|
||||
/// Recursively formats the tensor data for display and appends it to the provided accumulator string.
|
||||
///
|
||||
/// This function is designed to work with tensors of any dimensionality.
|
||||
/// It traverses the tensor dimensions recursively, converting the elements
|
||||
/// to strings and appending them to the accumulator string with the
|
||||
/// appropriate formatting.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
|
||||
/// * `depth` - The current depth of the tensor dimensions being processed.
|
||||
/// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
|
||||
fn display_recursive(&self, acc: &mut String, depth: usize, multi_index: &mut [usize]) {
|
||||
if depth == 0 {
|
||||
acc.push('[');
|
||||
}
|
||||
|
||||
if depth == self.dims().len() - 1 {
|
||||
// if we are at the innermost dimension, just push its elements into the accumulator
|
||||
for i in 0..self.dims()[depth] {
|
||||
if i > 0 {
|
||||
acc.push_str(", ");
|
||||
}
|
||||
multi_index[depth] = i;
|
||||
let range: [core::ops::Range<usize>; D] =
|
||||
core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
|
||||
let elem = &self.clone().index(range).to_data().value[0];
|
||||
acc.push_str(&format!("{:?}", elem));
|
||||
}
|
||||
} else {
|
||||
// otherwise, iterate through the current dimension and recursively display the inner tensors
|
||||
for i in 0..self.dims()[depth] {
|
||||
if i > 0 {
|
||||
acc.push_str(", ");
|
||||
}
|
||||
acc.push('[');
|
||||
multi_index[depth] = i;
|
||||
self.display_recursive(acc, depth + 1, multi_index);
|
||||
acc.push(']');
|
||||
}
|
||||
}
|
||||
|
||||
if depth == 0 {
|
||||
acc.push(']');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pretty print tensors
|
||||
impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
B::IntElem: core::fmt::Display,
|
||||
K: BasicOps<B>,
|
||||
<K as BasicOps<B>>::Elem: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
writeln!(f, "Tensor {{")?;
|
||||
write!(f, " data: ")?;
|
||||
|
||||
let mut acc = String::new();
|
||||
let mut multi_index = vec![0; D];
|
||||
self.display_recursive(&mut acc, 0, &mut multi_index);
|
||||
write!(f, "{}", acc)?;
|
||||
writeln!(f, ",")?;
|
||||
writeln!(f, " shape: {:?},", self.dims())?;
|
||||
writeln!(f, " device: {:?},", self.device())?;
|
||||
writeln!(f, " backend: {:?},", B::name())?;
|
||||
writeln!(f, " kind: {:?},", K::name())?;
|
||||
writeln!(f, " dtype: {:?},", K::elem_type_name())?;
|
||||
write!(f, "}}")
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that list all operations that can be applied on all tensors.
|
||||
///
|
||||
/// # Warnings
|
||||
///
|
||||
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
|
||||
pub trait BasicOps<B: Backend>: TensorKind<B> {
|
||||
type Elem;
|
||||
type Elem: 'static;
|
||||
|
||||
fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D>;
|
||||
fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D>;
|
||||
|
@ -310,6 +393,9 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
|
|||
rhs: Self::Primitive<D>,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
|
||||
fn elem_type_name() -> &'static str {
|
||||
core::any::type_name::<Self::Elem>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BasicOps<B> for Float {
|
||||
|
|
|
@ -9,16 +9,26 @@ pub struct Bool;
|
|||
|
||||
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
|
||||
type Primitive<const D: usize>: Clone + core::fmt::Debug;
|
||||
fn name() -> &'static str;
|
||||
}
|
||||
|
||||
impl<B: Backend> TensorKind<B> for Float {
|
||||
type Primitive<const D: usize> = B::TensorPrimitive<D>;
|
||||
fn name() -> &'static str {
|
||||
"Float"
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> TensorKind<B> for Int {
|
||||
type Primitive<const D: usize> = B::IntTensorPrimitive<D>;
|
||||
fn name() -> &'static str {
|
||||
"Int"
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> TensorKind<B> for Bool {
|
||||
type Primitive<const D: usize> = B::BoolTensorPrimitive<D>;
|
||||
fn name() -> &'static str {
|
||||
"Bool"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#[burn_tensor_testgen::testgen(stats)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
#[test]
|
||||
|
@ -13,4 +14,84 @@ mod tests {
|
|||
let data_expected = Data::from([[2.4892], [15.3333]]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_display_2d_int_tensor() {
|
||||
let int_data = Data::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
|
||||
let tensor_int: burn_tensor::Tensor<TestBackend, 2, burn_tensor::Int> =
|
||||
Tensor::from_data(int_data);
|
||||
|
||||
let output = format!("{}", tensor_int);
|
||||
let expected = format!(
|
||||
"Tensor {{\n data: [[1, 2, 3], [4, 5, 6], [7, 8, 9]],\n shape: [3, 3],\n device: Cpu,\n backend: \"{}\",\n kind: \"Int\",\n dtype: \"i64\",\n}}",
|
||||
TestBackend::name()
|
||||
);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_display_2d_float_tensor() {
|
||||
let float_data = Data::from([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]);
|
||||
let tensor_float: burn_tensor::Tensor<TestBackend, 2, burn_tensor::Float> =
|
||||
Tensor::from_data(float_data);
|
||||
|
||||
let output = format!("{}", tensor_float);
|
||||
let expected = format!(
|
||||
"Tensor {{\n data: [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]],\n shape: [3, 3],\n device: Cpu,\n backend: \"{}\",\n kind: \"Float\",\n dtype: \"f32\",\n}}",
|
||||
TestBackend::name()
|
||||
);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_display_2d_bool_tensor() {
|
||||
let bool_data = Data::from([
|
||||
[true, false, true],
|
||||
[false, true, false],
|
||||
[false, true, true],
|
||||
]);
|
||||
let tensor_bool: burn_tensor::Tensor<TestBackend, 2, burn_tensor::Bool> =
|
||||
Tensor::from_data(bool_data);
|
||||
|
||||
let output = format!("{}", tensor_bool);
|
||||
let expected = format!(
|
||||
"Tensor {{\n data: [[true, false, true], [false, true, false], [false, true, true]],\n shape: [3, 3],\n device: Cpu,\n backend: \"{}\",\n kind: \"Bool\",\n dtype: \"bool\",\n}}",
|
||||
TestBackend::name()
|
||||
);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_display_3d_tensor() {
|
||||
let data = Data::from([
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
|
||||
]);
|
||||
let tensor: burn_tensor::Tensor<TestBackend, 3, burn_tensor::Int> = Tensor::from_data(data);
|
||||
|
||||
let output = format!("{}", tensor);
|
||||
let expected = format!(
|
||||
"Tensor {{\n data: [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], \
|
||||
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]],\n shape: [2, 3, 4],\n device: Cpu,\n backend: \"{}\",\n kind: \"Int\",\n dtype: \"i64\",\n}}",
|
||||
TestBackend::name()
|
||||
);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_display_4d_tensor() {
|
||||
let data = Data::from([
|
||||
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
|
||||
[[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]],
|
||||
]);
|
||||
|
||||
let tensor: burn_tensor::Tensor<TestBackend, 4, burn_tensor::Int> = Tensor::from_data(data);
|
||||
|
||||
let output = format!("{}", tensor);
|
||||
let expected = format!(
|
||||
"Tensor {{\n data: [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]],\n shape: [2, 2, 2, 3],\n device: Cpu,\n backend: \"{}\",\n kind: \"Int\",\n dtype: \"i64\",\n}}",
|
||||
TestBackend::name()
|
||||
);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue