Precision option for tensor display (#2139)

Dilshod Tadjibaev 2024-08-08 15:01:42 -05:00 committed by GitHub
parent 27ca6cee95
commit 1c681f46ec
3 changed files with 207 additions and 46 deletions


@@ -131,41 +131,41 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------------------------------------------- |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
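For illustration, here is how a few of these operations look in Burn code (a minimal sketch; `Backend` is assumed to be a type alias for the backend in use, as in the display examples further below):

```rust
let device = Default::default();
let a = Tensor::<Backend, 2>::zeros([2, 3], &device);
let b = Tensor::<Backend, 2>::ones([2, 3], &device);

// Concatenate along dimension 0: [2, 3] + [2, 3] -> [4, 3].
let cat = Tensor::cat(vec![a.clone(), b.clone()], 0);

// `reshape` is the equivalent of PyTorch's `view`.
let reshaped = a.reshape([3, 2]);

// `slice` takes one range per dimension: rows 0..1, columns 1..3.
let sliced = b.slice([0..1, 1..3]);
```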
### Numeric Operations
@@ -332,3 +332,78 @@ strategies.
| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |
## Displaying Tensor Details
Burn provides flexible options for displaying tensor information, allowing you to control the level
of detail and formatting to suit your needs.
### Basic Display
To display a detailed view of a tensor, use Rust's `println!` or `format!` macros:
```rust
let tensor = Tensor::<Backend, 2>::full([2, 3], 0.123456789, &Default::default());
println!("{}", tensor);
```
This will output:
```
Tensor {
data:
[[0.12345679, 0.12345679, 0.12345679],
[0.12345679, 0.12345679, 0.12345679]],
shape: [2, 3],
device: Cpu,
backend: "ndarray",
kind: "Float",
dtype: "f32",
}
```
### Controlling Precision
You can control the number of decimal places displayed using Rust's formatting syntax:
```rust
println!("{:.2}", tensor);
```
Output:
```
Tensor {
data:
[[0.12, 0.12, 0.12],
[0.12, 0.12, 0.12]],
shape: [2, 3],
device: Cpu,
backend: "ndarray",
kind: "Float",
dtype: "f32",
}
```
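Because the precision comes from Rust's formatter, it can also be chosen at runtime with the `{:.prec$}` syntax (a small sketch; `prec` is just a local variable introduced for illustration):

```rust
let prec = 3;
println!("{:.prec$}", tensor); // equivalent to println!("{:.3}", tensor)
```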
### Global Print Options
For more fine-grained control over tensor printing, Burn provides a `PrintOptions` struct and a
`set_print_options` function:
```rust
use burn::tensor::{set_print_options, PrintOptions};
let print_options = PrintOptions {
precision: Some(2),
..Default::default()
};
set_print_options(print_options);
```
Options:

- `precision`: Number of decimal places for floating-point numbers (default: `None`, which uses
  Rust's default formatting)
- `threshold`: Maximum number of elements before the output is summarized (default: 1000)
- `edge_items`: Number of items to show at the beginning and end of each dimension when summarizing
  (default: 3)
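Putting these options together, a summarized printout might look as follows (a sketch that assumes the same `Backend` alias as in the earlier examples; exact spacing can vary):

```rust
use burn::tensor::{set_print_options, PrintOptions, Tensor};

let print_options = PrintOptions {
    threshold: 10,  // summarize any tensor with more than 10 elements
    edge_items: 2,  // show 2 items at each edge of a summarized dimension
    precision: Some(4),
};
set_print_options(print_options);

// 100 elements > threshold of 10, so the data is summarized, e.g.:
// [0.0000, 0.0000, ..., 0.0000, 0.0000]
let tensor = Tensor::<Backend, 1>::zeros([100], &Default::default());
println!("{}", tensor);
```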


@@ -6,7 +6,7 @@ use alloc::format;
use alloc::string::String;
use alloc::vec;
-use burn_common::stub::Mutex;
+use burn_common::stub::RwLock;
use core::future::Future;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
@@ -1021,13 +1021,13 @@ where
acc.push(' ');
}
}
fn fmt_inner_tensor(
&self,
acc: &mut String,
depth: usize,
multi_index: &mut [usize],
range: (usize, usize),
+precision: Option<usize>,
) {
let (start, end) = range;
for i in start..end {
@@ -1043,7 +1043,10 @@ where
if let Some(data) = data {
let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
-acc.push_str(&format!("{elem:?}"));
+match (precision, K::name()) {
+    (Some(p), "Float") => acc.push_str(&format!("{:.1$}", elem, p)),
+    _ => acc.push_str(&format!("{:?}", elem)),
+}
} else {
acc.push_str("<Tensor data not available>");
}
@@ -1102,7 +1105,13 @@ where
// if we are at the innermost dimension, just push its elements into the accumulator
if summarize && self.dims()[depth] > 2 * edge_items {
// print the starting `edge_items` elements
-self.fmt_inner_tensor(acc, depth, multi_index, (0, edge_items));
+self.fmt_inner_tensor(
+    acc,
+    depth,
+    multi_index,
+    (0, edge_items),
+    print_options.precision,
+);
acc.push_str(", ...");
// print the last `edge_items` elements
self.fmt_inner_tensor(
@@ -1110,10 +1119,17 @@ where
depth,
multi_index,
(self.dims()[depth] - edge_items, self.dims()[depth]),
+print_options.precision,
);
} else {
// print all the elements
-self.fmt_inner_tensor(acc, depth, multi_index, (0, self.dims()[depth]));
+self.fmt_inner_tensor(
+    acc,
+    depth,
+    multi_index,
+    (0, self.dims()[depth]),
+    print_options.precision,
+);
}
} else {
// otherwise, iterate through the current dimension and recursively display the inner tensors
@@ -1158,29 +1174,42 @@ where
}
}
#[derive(Clone, Debug)]
/// Options for Tensor pretty printing
pub struct PrintOptions {
/// Number of elements above which the tensor output is summarized
pub threshold: usize,
/// Number of elements shown at the start and end of each dimension when summarizing
pub edge_items: usize,
+/// Precision used for floating-point numbers
+pub precision: Option<usize>,
}
-static PRINT_OPTS: Mutex<PrintOptions> = Mutex::new(PrintOptions::const_default());
+static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
impl PrintOptions {
-// We cannot use the default trait as it's not const.
-const fn const_default() -> Self {
+/// Print options with default values
+pub const fn const_default() -> Self {
Self {
threshold: 1000,
edge_items: 3,
precision: None,
}
}
}
+impl Default for PrintOptions {
+    fn default() -> Self {
+        Self::const_default()
+    }
+}
/// Set print options
pub fn set_print_options(options: PrintOptions) {
-*PRINT_OPTS.lock().unwrap() = options
+let mut print_opts = PRINT_OPTS.write().unwrap();
+*print_opts = options;
}
/// Pretty print tensors
@ -1195,7 +1224,15 @@ where
writeln!(f, "Tensor {{")?;
{
-let po = PRINT_OPTS.lock().unwrap();
+// Do not hold the lock for the whole function
+let mut po = { PRINT_OPTS.read().unwrap().clone() };
+// Override the precision if it is set from the formatter
+// (i.e. when the tensor is printed with a precision specifier such as `{:.3}`)
+if let Some(precision) = f.precision() {
+    po.precision = Some(precision);
+}
let mut acc = String::new();
let mut multi_index = vec![0; D];
let summarize = self.shape().num_elements() > po.threshold;
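As an aside, the `{:.1$}` format string used in `fmt_inner_tensor` is standard Rust formatting syntax: the precision is read from the argument at index 1. A standalone illustration:

```rust
let elem = 0.123456789_f32;
let p = 2;
// `{:.1$}` takes the precision from the argument at index 1 (`p`).
assert_eq!(format!("{:.1$}", elem, p), "0.12");
// Equivalent named-capture form:
assert_eq!(format!("{elem:.p$}"), "0.12");
```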


@@ -2,7 +2,7 @@
mod tests {
use super::*;
use burn_tensor::backend::Backend;
-use burn_tensor::{Shape, Tensor, TensorData};
+use burn_tensor::{set_print_options, PrintOptions, Shape, Tensor, TensorData};
type FloatElem = <TestBackend as Backend>::FloatElem;
type IntElem = <TestBackend as Backend>::IntElem;
@@ -275,6 +275,55 @@ mod tests {
backend: {:?},
kind: "Float",
dtype: "f32",
}}"#,
tensor.device(),
TestBackend::name(),
);
assert_eq!(output, expected);
}
#[test]
fn test_display_precision() {
let tensor = Tensor::<TestBackend, 2>::full([1, 1], 0.123456789, &Default::default());
let output = format!("{}", tensor);
let expected = format!(
r#"Tensor {{
data:
[[0.12345679]],
shape: [1, 1],
device: {:?},
backend: {:?},
kind: "Float",
dtype: "f32",
}}"#,
tensor.device(),
TestBackend::name(),
);
assert_eq!(output, expected);
// Note: we can't call `set_print_options` here, because it mutates global
// state that would leak into other tests:
// let print_options = PrintOptions {
// precision: Some(3),
// ..Default::default()
// };
// set_print_options(print_options);
let tensor = Tensor::<TestBackend, 2>::full([3, 2], 0.123456789, &Default::default());
// Set precision to 3
let output = format!("{:.3}", tensor);
let expected = format!(
r#"Tensor {{
data:
[[0.123, 0.123],
[0.123, 0.123],
[0.123, 0.123]],
shape: [3, 2],
device: {:?},
backend: {:?},
kind: "Float",
dtype: "f32",
}}"#,
tensor.device(),
TestBackend::name(),