mirror of https://github.com/tracel-ai/burn.git
Precision option for tensor display (#2139)
This commit is contained in:
parent
27ca6cee95
commit
1c681f46ec
|
@ -131,41 +131,41 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
|
|||
|
||||
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
||||
|
||||
| Burn | PyTorch Equivalent |
|
||||
| ------------------------------------- | ------------------------------------------------------------------------ |
|
||||
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
|
||||
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
|
||||
| `Tensor::from_primitive(primitive)` | N/A |
|
||||
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
|
||||
| `tensor.all()` | `tensor.all()` |
|
||||
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
|
||||
| `tensor.any()` | `tensor.any()` |
|
||||
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
|
||||
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
|
||||
| `tensor.device()` | `tensor.device` |
|
||||
| `tensor.dims()` | `tensor.size()` |
|
||||
| `tensor.equal(other)` | `x == y` |
|
||||
| `tensor.expand(shape)` | `tensor.expand(shape)` |
|
||||
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
|
||||
| `tensor.flip(axes)` | `tensor.flip(axes)` |
|
||||
| `tensor.into_data()` | N/A |
|
||||
| `tensor.into_primitive()` | N/A |
|
||||
| `tensor.into_scalar()` | `tensor.item()` |
|
||||
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
|
||||
| `tensor.not_equal(other)` | `x != y` |
|
||||
| `tensor.permute(axes)` | `tensor.permute(axes)` |
|
||||
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
|
||||
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])`|
|
||||
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
|
||||
| `tensor.reshape(shape)` | `tensor.view(shape)` |
|
||||
| `tensor.shape()` | `tensor.shape` |
|
||||
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
|
||||
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
|
||||
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
|
||||
| `tensor.to_data()` | N/A |
|
||||
| `tensor.to_device(device)` | `tensor.to(device)` |
|
||||
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
|
||||
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
|
||||
| Burn | PyTorch Equivalent |
|
||||
| ------------------------------------- | ------------------------------------------------------------------------- |
|
||||
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
|
||||
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
|
||||
| `Tensor::from_primitive(primitive)` | N/A |
|
||||
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
|
||||
| `tensor.all()` | `tensor.all()` |
|
||||
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
|
||||
| `tensor.any()` | `tensor.any()` |
|
||||
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
|
||||
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
|
||||
| `tensor.device()` | `tensor.device` |
|
||||
| `tensor.dims()` | `tensor.size()` |
|
||||
| `tensor.equal(other)` | `x == y` |
|
||||
| `tensor.expand(shape)` | `tensor.expand(shape)` |
|
||||
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
|
||||
| `tensor.flip(axes)` | `tensor.flip(axes)` |
|
||||
| `tensor.into_data()` | N/A |
|
||||
| `tensor.into_primitive()` | N/A |
|
||||
| `tensor.into_scalar()` | `tensor.item()` |
|
||||
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
|
||||
| `tensor.not_equal(other)` | `x != y` |
|
||||
| `tensor.permute(axes)` | `tensor.permute(axes)` |
|
||||
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
|
||||
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
|
||||
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
|
||||
| `tensor.reshape(shape)` | `tensor.view(shape)` |
|
||||
| `tensor.shape()` | `tensor.shape` |
|
||||
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
|
||||
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
|
||||
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
|
||||
| `tensor.to_data()` | N/A |
|
||||
| `tensor.to_device(device)` | `tensor.to(device)` |
|
||||
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
|
||||
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
|
||||
|
||||
### Numeric Operations
|
||||
|
||||
|
@ -332,3 +332,78 @@ strategies.
|
|||
| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
|
||||
| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
|
||||
| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |
|
||||
|
||||
## Displaying Tensor Details
|
||||
|
||||
Burn provides flexible options for displaying tensor information, allowing you to control the level
|
||||
of detail and formatting to suit your needs.
|
||||
|
||||
### Basic Display
|
||||
|
||||
To display a detailed view of a tensor, you can simply use Rust's `println!` or `format!` macros:
|
||||
|
||||
```rust
|
||||
let tensor = Tensor::<Backend, 2>::full([2, 3], 0.123456789, &Default::default());
|
||||
println!("{}", tensor);
|
||||
```
|
||||
|
||||
This will output:
|
||||
|
||||
```
|
||||
Tensor {
|
||||
data:
|
||||
[[0.12345679, 0.12345679, 0.12345679],
|
||||
[0.12345679, 0.12345679, 0.12345679]],
|
||||
shape: [2, 3],
|
||||
device: Cpu,
|
||||
backend: "ndarray",
|
||||
kind: "Float",
|
||||
dtype: "f32",
|
||||
}
|
||||
```
|
||||
|
||||
### Controlling Precision
|
||||
|
||||
You can control the number of decimal places displayed using Rust's formatting syntax:
|
||||
|
||||
```rust
|
||||
println!("{:.2}", tensor);
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
Tensor {
|
||||
data:
|
||||
[[0.12, 0.12, 0.12],
|
||||
[0.12, 0.12, 0.12]],
|
||||
shape: [2, 3],
|
||||
device: Cpu,
|
||||
backend: "ndarray",
|
||||
kind: "Float",
|
||||
dtype: "f32",
|
||||
}
|
||||
```
|
||||
|
||||
### Global Print Options
|
||||
|
||||
For more fine-grained control over tensor printing, Burn provides a `PrintOptions` struct and a
|
||||
`set_print_options` function:
|
||||
|
||||
```rust
|
||||
use burn::tensor::{set_print_options, PrintOptions};
|
||||
|
||||
let print_options = PrintOptions {
|
||||
precision: Some(2),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
set_print_options(print_options);
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `precision`: Number of decimal places for floating-point numbers (default: None)
|
||||
- `threshold`: Maximum number of elements to display before summarizing (default: 1000)
|
||||
- `edge_items`: Number of items to show at the beginning and end of each dimension when summarizing
|
||||
(default: 3)
|
||||
|
|
|
@ -6,7 +6,7 @@ use alloc::format;
|
|||
use alloc::string::String;
|
||||
use alloc::vec;
|
||||
|
||||
use burn_common::stub::Mutex;
|
||||
use burn_common::stub::RwLock;
|
||||
use core::future::Future;
|
||||
use core::iter::repeat;
|
||||
use core::{fmt::Debug, ops::Range};
|
||||
|
@ -1021,13 +1021,13 @@ where
|
|||
acc.push(' ');
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_inner_tensor(
|
||||
&self,
|
||||
acc: &mut String,
|
||||
depth: usize,
|
||||
multi_index: &mut [usize],
|
||||
range: (usize, usize),
|
||||
precision: Option<usize>,
|
||||
) {
|
||||
let (start, end) = range;
|
||||
for i in start..end {
|
||||
|
@ -1043,7 +1043,10 @@ where
|
|||
|
||||
if let Some(data) = data {
|
||||
let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
|
||||
acc.push_str(&format!("{elem:?}"));
|
||||
match (precision, K::name()) {
|
||||
(Some(p), "Float") => acc.push_str(&format!("{:.1$}", elem, p)),
|
||||
_ => acc.push_str(&format!("{:?}", elem)),
|
||||
}
|
||||
} else {
|
||||
acc.push_str("<Tensor data not available>");
|
||||
}
|
||||
|
@ -1102,7 +1105,13 @@ where
|
|||
// if we are at the innermost dimension, just push its elements into the accumulator
|
||||
if summarize && self.dims()[depth] > 2 * edge_items {
|
||||
// print the starting `edge_items` elements
|
||||
self.fmt_inner_tensor(acc, depth, multi_index, (0, edge_items));
|
||||
self.fmt_inner_tensor(
|
||||
acc,
|
||||
depth,
|
||||
multi_index,
|
||||
(0, edge_items),
|
||||
print_options.precision,
|
||||
);
|
||||
acc.push_str(", ...");
|
||||
// print the last `edge_items` elements
|
||||
self.fmt_inner_tensor(
|
||||
|
@ -1110,10 +1119,17 @@ where
|
|||
depth,
|
||||
multi_index,
|
||||
(self.dims()[depth] - edge_items, self.dims()[depth]),
|
||||
print_options.precision,
|
||||
);
|
||||
} else {
|
||||
// print all the elements
|
||||
self.fmt_inner_tensor(acc, depth, multi_index, (0, self.dims()[depth]));
|
||||
self.fmt_inner_tensor(
|
||||
acc,
|
||||
depth,
|
||||
multi_index,
|
||||
(0, self.dims()[depth]),
|
||||
print_options.precision,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// otherwise, iterate through the current dimension and recursively display the inner tensors
|
||||
|
@ -1158,29 +1174,42 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// Options for Tensor pretty printing
|
||||
pub struct PrintOptions {
|
||||
/// number of elements to start summarizing tensor
|
||||
pub threshold: usize,
|
||||
|
||||
/// number of starting elements and ending elements to display
|
||||
pub edge_items: usize,
|
||||
|
||||
/// Precision for floating point numbers
|
||||
pub precision: Option<usize>,
|
||||
}
|
||||
|
||||
static PRINT_OPTS: Mutex<PrintOptions> = Mutex::new(PrintOptions::const_default());
|
||||
static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
|
||||
|
||||
impl PrintOptions {
|
||||
// We cannot use the default trait as it's not const.
|
||||
const fn const_default() -> Self {
|
||||
/// Print options with default values
|
||||
pub const fn const_default() -> Self {
|
||||
Self {
|
||||
threshold: 1000,
|
||||
edge_items: 3,
|
||||
precision: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PrintOptions {
|
||||
fn default() -> Self {
|
||||
Self::const_default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set print options
|
||||
pub fn set_print_options(options: PrintOptions) {
|
||||
*PRINT_OPTS.lock().unwrap() = options
|
||||
let mut print_opts = PRINT_OPTS.write().unwrap();
|
||||
*print_opts = options;
|
||||
}
|
||||
|
||||
/// Pretty print tensors
|
||||
|
@ -1195,7 +1224,15 @@ where
|
|||
writeln!(f, "Tensor {{")?;
|
||||
|
||||
{
|
||||
let po = PRINT_OPTS.lock().unwrap();
|
||||
// Do not lock the mutex for the whole function
|
||||
let mut po = { PRINT_OPTS.read().unwrap().clone() };
|
||||
|
||||
// Override the precision if it is set from the formatter
|
||||
// This will be possible when the tensor is printed using the `{:.*}` syntax
|
||||
if let Some(precision) = f.precision() {
|
||||
po.precision = Some(precision);
|
||||
}
|
||||
|
||||
let mut acc = String::new();
|
||||
let mut multi_index = vec![0; D];
|
||||
let summarize = self.shape().num_elements() > po.threshold;
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Shape, Tensor, TensorData};
|
||||
use burn_tensor::{set_print_options, PrintOptions, Shape, Tensor, TensorData};
|
||||
|
||||
type FloatElem = <TestBackend as Backend>::FloatElem;
|
||||
type IntElem = <TestBackend as Backend>::IntElem;
|
||||
|
@ -275,6 +275,55 @@ mod tests {
|
|||
backend: {:?},
|
||||
kind: "Float",
|
||||
dtype: "f32",
|
||||
}}"#,
|
||||
tensor.device(),
|
||||
TestBackend::name(),
|
||||
);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
#[test]
|
||||
fn test_display_precision() {
|
||||
let tensor = Tensor::<TestBackend, 2>::full([1, 1], 0.123456789, &Default::default());
|
||||
|
||||
let output = format!("{}", tensor);
|
||||
let expected = format!(
|
||||
r#"Tensor {{
|
||||
data:
|
||||
[[0.12345679]],
|
||||
shape: [1, 1],
|
||||
device: {:?},
|
||||
backend: {:?},
|
||||
kind: "Float",
|
||||
dtype: "f32",
|
||||
}}"#,
|
||||
tensor.device(),
|
||||
TestBackend::name(),
|
||||
);
|
||||
assert_eq!(output, expected);
|
||||
|
||||
// CAN'T DO THIS BECAUSE OF GLOBAL STATE
|
||||
// let print_options = PrintOptions {
|
||||
// precision: Some(3),
|
||||
// ..Default::default()
|
||||
// };
|
||||
// set_print_options(print_options);
|
||||
|
||||
let tensor = Tensor::<TestBackend, 2>::full([3, 2], 0.123456789, &Default::default());
|
||||
|
||||
// Set precision to 3
|
||||
let output = format!("{:.3}", tensor);
|
||||
|
||||
let expected = format!(
|
||||
r#"Tensor {{
|
||||
data:
|
||||
[[0.123, 0.123],
|
||||
[0.123, 0.123],
|
||||
[0.123, 0.123]],
|
||||
shape: [3, 2],
|
||||
device: {:?},
|
||||
backend: {:?},
|
||||
kind: "Float",
|
||||
dtype: "f32",
|
||||
}}"#,
|
||||
tensor.device(),
|
||||
TestBackend::name(),
|
||||
|
|
Loading…
Reference in New Issue