Add closeness tensor report (#2184)

* Add closeness tensor report

* Add documentation section

* Fix for no-std

* Fix epsilon formatting

* Update report.rs

* Fix import references

* Fix doc test

* Use colored crate instead of passing codes

* Small refactor to use iter directly

* Move colored dep to std

* Add missing

* Fix missing epsilon
This commit is contained in:
Dilshod Tadjibaev 2024-08-22 10:19:27 -05:00 committed by GitHub
parent 77f8121d44
commit 75a2850047
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 174 additions and 2 deletions

1
Cargo.lock generated
View File

@ -706,6 +706,7 @@ dependencies = [
"burn-common", "burn-common",
"burn-tensor-testgen", "burn-tensor-testgen",
"bytemuck", "bytemuck",
"colored",
"cubecl", "cubecl",
"derive-new", "derive-new",
"half", "half",

View File

@ -407,3 +407,50 @@ Options:
- `threshold`: Maximum number of elements to display before summarizing (default: 1000) - `threshold`: Maximum number of elements to display before summarizing (default: 1000)
- `edge_items`: Number of items to show at the beginning and end of each dimension when summarizing - `edge_items`: Number of items to show at the beginning and end of each dimension when summarizing
(default: 3) (default: 3)
### Checking Tensor Closeness
Burn provides a utility function `check_closeness` to compare two tensors and assess their
similarity. This function is particularly useful for debugging and validating tensor operations,
especially when working with floating-point arithmetic where small numerical differences can
accumulate. It's also valuable when comparing model outputs during the process of importing models
from other frameworks, helping to ensure that the imported model produces results consistent with
the original.
Here's an example of how to use `check_closeness`:
```rust
use burn::tensor::{check_closeness, Tensor};
type B = burn::backend::NdArray;
let device = Default::default();
let tensor1 = Tensor::<B, 1>::from_floats(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],
&device,
);
let tensor2 = Tensor::<B, 1>::from_floats(
[1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],
&device,
);
check_closeness(&tensor1, &tensor2);
```
The `check_closeness` function compares the two input tensors element-wise with `is_close`, using
each value from a range of epsilons (`1e-1` down to `1e-8`) as both the relative and the absolute
tolerance. It then prints a detailed report showing the number and percentage of elements that are
within each tolerance level.
The output provides a breakdown for different epsilon values, allowing you to assess the closeness
of the tensors at various precision levels. This is particularly helpful when dealing with
operations that may introduce small numerical discrepancies.
The function uses color-coded output to highlight the results:
- Green [PASS]: All elements are within the specified tolerance.
- Yellow [WARN]: Most elements (90% or more) are within tolerance.
- Red [FAIL]: Significant differences are detected.
This utility can be invaluable when implementing or debugging tensor operations, especially those
involving complex mathematical computations or when porting algorithms from other frameworks. It's
also an essential tool when verifying the accuracy of imported models, ensuring that the Burn
implementation produces results that closely match those of the original model.

View File

@ -15,14 +15,21 @@ default = ["std", "repr"]
doc = ["default"] doc = ["default"]
experimental-named-tensor = [] experimental-named-tensor = []
export_tests = ["burn-tensor-testgen"] export_tests = ["burn-tensor-testgen"]
std = ["rand/std", "half/std", "num-traits/std", "burn-common/std", "burn-common/rayon"] std = [
"rand/std",
"half/std",
"num-traits/std",
"burn-common/std",
"burn-common/rayon",
"colored",
]
repr = [] repr = []
cubecl = ["dep:cubecl"] cubecl = ["dep:cubecl"]
cubecl-wgpu = ["cubecl", "cubecl/wgpu"] cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
cubecl-cuda = ["cubecl", "cubecl/cuda"] cubecl-cuda = ["cubecl", "cubecl/cuda"]
[dependencies] [dependencies]
burn-common = { path = "../burn-common", version = "0.14.0", default-features = false} burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.14.0", optional = true } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.14.0", optional = true }
cubecl = { workspace = true, optional = true } cubecl = { workspace = true, optional = true }
@ -32,6 +39,7 @@ num-traits = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
rand_distr = { workspace = true } # use instead of statrs because it supports no_std rand_distr = { workspace = true } # use instead of statrs because it supports no_std
bytemuck = { workspace = true } bytemuck = { workspace = true }
colored = { workspace = true, optional = true }
# The same implementation of HashMap in std but with no_std support (only needs alloc crate) # The same implementation of HashMap in std but with no_std support (only needs alloc crate)
hashbrown = { workspace = true } # no_std compatible hashbrown = { workspace = true } # no_std compatible

View File

@ -33,6 +33,12 @@ pub mod ops;
/// Tensor quantization module. /// Tensor quantization module.
pub mod quantization; pub mod quantization;
// The closeness report prints with `println!` and depends on the optional `colored`
// crate (enabled by the `std` feature), so it is only compiled and re-exported with `std`.
#[cfg(feature = "std")]
pub use report::*;
#[cfg(feature = "std")]
mod report;
#[cfg(feature = "experimental-named-tensor")] #[cfg(feature = "experimental-named-tensor")]
mod named; mod named;
#[cfg(feature = "experimental-named-tensor")] #[cfg(feature = "experimental-named-tensor")]

View File

@ -0,0 +1,110 @@
use super::{backend::Backend, Tensor};
use colored::*;
/// Checks the closeness of two tensors and prints the results.
///
/// For each epsilon in `1e-1, 1e-2, ..., 1e-8` the tensors are compared element-wise
/// with `Tensor::is_close`, passing the epsilon as both the relative and the absolute
/// tolerance. The number and percentage of elements found close are printed, followed
/// by a color-coded verdict:
///
/// * `[PASS]` (green) - every element is within tolerance.
/// * `[WARN]` (yellow) - at least 90% of the elements are within tolerance.
/// * `[FAIL]` (red) - fewer than 90% of the elements are within tolerance.
///
/// Tensors with zero elements are reported as `100.00%` / `[PASS]`, since there is no
/// element that falls outside the tolerance.
///
/// # Arguments
///
/// * `output` - The output tensor.
/// * `expected` - The expected tensor.
///
/// # Example
///
/// ```no_run
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{check_closeness, Tensor};
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor1 = Tensor::<B, 1>::from_floats(
///         [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],
///         &device,
///     );
///     let tensor2 = Tensor::<B, 1>::from_floats(
///         [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],
///         &device,
///     );
///     check_closeness(&tensor1, &tensor2);
/// }
/// ```
///
/// # Output
///
/// ```text
/// Tensor Closeness Check Results:
/// ===============================
/// Epsilon: 1e-1
///   Close elements: 10/10 (100.00%)
///   [PASS] All elements are within tolerance
///
/// Epsilon: 1e-2
///   Close elements: 10/10 (100.00%)
///   [PASS] All elements are within tolerance
///
/// Epsilon: 1e-3
///   Close elements: 9/10 (90.00%)
///   [WARN] Most elements are within tolerance
///
/// Epsilon: 1e-4
///   Close elements: 6/10 (60.00%)
///   [FAIL] Significant differences detected
///
/// Epsilon: 1e-5
///   Close elements: 5/10 (50.00%)
///   [FAIL] Significant differences detected
///
/// Epsilon: 1e-6
///   Close elements: 5/10 (50.00%)
///   [FAIL] Significant differences detected
///
/// Epsilon: 1e-7
///   Close elements: 5/10 (50.00%)
///   [FAIL] Significant differences detected
///
/// Epsilon: 1e-8
///   Close elements: 5/10 (50.00%)
///   [FAIL] Significant differences detected
///
/// Closeness check complete.
/// ```
pub fn check_closeness<B: Backend, const D: usize>(output: &Tensor<B, D>, expected: &Tensor<B, D>) {
    println!("{}", "Tensor Closeness Check Results:".bold());
    println!("===============================");

    for epsilon in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] {
        println!("{} {:e}", "Epsilon:".bold(), epsilon);

        // The same epsilon is used for both tolerances accepted by `is_close`.
        let data = output
            .clone()
            .is_close(expected.clone(), Some(epsilon), Some(epsilon))
            .into_data();

        let num_elements = data.num_elements();

        // Count the number of elements that are close (true)
        let count = data.iter::<bool>().filter(|&is_close| is_close).count();

        // Guard against 0/0 (NaN) for empty tensors: with no elements, nothing is out
        // of tolerance, so report the comparison as fully close.
        let percentage = if num_elements == 0 {
            100.0
        } else {
            (count as f64 / num_elements as f64) * 100.0
        };

        println!(
            "  Close elements: {}/{} ({:.2}%)",
            count, num_elements, percentage
        );

        // Decide PASS on the exact integer counts rather than on a float equality.
        if count == num_elements {
            println!("  {} All elements are within tolerance", "[PASS]".green());
        } else if percentage >= 90.0 {
            println!("  {} Most elements are within tolerance", "[WARN]".yellow());
        } else {
            println!("  {} Significant differences detected", "[FAIL]".red());
        }

        println!();
    }

    println!("{}", "Closeness check complete.".bold());
}