mirror of https://github.com/tracel-ai/burn.git
Add closeness tensor report (#2184)
* Add closeness tensor report * Add documentation section * Fix for no-std * Fix epsilon formatting * Update report.rs * Fix import references * Fix doc test * Use colored crate instead of passing codes * Small refactor to use iter directly * Move colored dep to std * Add missing * Fix missing epsilon
This commit is contained in:
parent
77f8121d44
commit
75a2850047
|
@ -706,6 +706,7 @@ dependencies = [
|
||||||
"burn-common",
|
"burn-common",
|
||||||
"burn-tensor-testgen",
|
"burn-tensor-testgen",
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
|
"colored",
|
||||||
"cubecl",
|
"cubecl",
|
||||||
"derive-new",
|
"derive-new",
|
||||||
"half",
|
"half",
|
||||||
|
|
|
@ -407,3 +407,50 @@ Options:
|
||||||
- `threshold`: Maximum number of elements to display before summarizing (default: 1000)
|
- `threshold`: Maximum number of elements to display before summarizing (default: 1000)
|
||||||
- `edge_items`: Number of items to show at the beginning and end of each dimension when summarizing
|
- `edge_items`: Number of items to show at the beginning and end of each dimension when summarizing
|
||||||
(default: 3)
|
(default: 3)
|
||||||
|
|
||||||
|
### Checking Tensor Closeness
|
||||||
|
|
||||||
|
Burn provides a utility function `check_closeness` to compare two tensors and assess their
|
||||||
|
similarity. This function is particularly useful for debugging and validating tensor operations,
|
||||||
|
especially when working with floating-point arithmetic where small numerical differences can
|
||||||
|
accumulate. It's also valuable when comparing model outputs during the process of importing models
|
||||||
|
from other frameworks, helping to ensure that the imported model produces results consistent with
|
||||||
|
the original.
|
||||||
|
|
||||||
|
Here's an example of how to use `check_closeness`:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use burn::tensor::{check_closeness, Tensor};
|
||||||
|
type B = burn::backend::NdArray;
|
||||||
|
|
||||||
|
let device = Default::default();
|
||||||
|
let tensor1 = Tensor::<B, 1>::from_floats(
|
||||||
|
[1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
let tensor2 = Tensor::<B, 1>::from_floats(
|
||||||
|
[1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
|
||||||
|
check_closeness(&tensor1, &tensor2);
|
||||||
|
```
|
||||||
|
|
||||||
|
The `check_closeness` function compares the two input tensors element-wise, checking their
|
||||||
|
absolute differences against a range of epsilon values. It then prints a detailed report showing
|
||||||
|
the percentage of elements that are within each tolerance level.
|
||||||
|
|
||||||
|
The output provides a breakdown for different epsilon values, allowing you to assess the closeness
|
||||||
|
of the tensors at various precision levels. This is particularly helpful when dealing with
|
||||||
|
operations that may introduce small numerical discrepancies.
|
||||||
|
|
||||||
|
The function uses color-coded output to highlight the results:
|
||||||
|
|
||||||
|
- Green [PASS]: All elements are within the specified tolerance.
|
||||||
|
- Yellow [WARN]: Most elements (90% or more) are within tolerance.
|
||||||
|
- Red [FAIL]: Significant differences are detected.
|
||||||
|
|
||||||
|
This utility can be invaluable when implementing or debugging tensor operations, especially those
|
||||||
|
involving complex mathematical computations or when porting algorithms from other frameworks. It's
|
||||||
|
also an essential tool when verifying the accuracy of imported models, ensuring that the Burn
|
||||||
|
implementation produces results that closely match those of the original model.
|
||||||
|
|
|
@ -15,14 +15,21 @@ default = ["std", "repr"]
|
||||||
doc = ["default"]
|
doc = ["default"]
|
||||||
experimental-named-tensor = []
|
experimental-named-tensor = []
|
||||||
export_tests = ["burn-tensor-testgen"]
|
export_tests = ["burn-tensor-testgen"]
|
||||||
std = ["rand/std", "half/std", "num-traits/std", "burn-common/std", "burn-common/rayon"]
|
std = [
|
||||||
|
"rand/std",
|
||||||
|
"half/std",
|
||||||
|
"num-traits/std",
|
||||||
|
"burn-common/std",
|
||||||
|
"burn-common/rayon",
|
||||||
|
"colored",
|
||||||
|
]
|
||||||
repr = []
|
repr = []
|
||||||
cubecl = ["dep:cubecl"]
|
cubecl = ["dep:cubecl"]
|
||||||
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
|
cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
|
||||||
cubecl-cuda = ["cubecl", "cubecl/cuda"]
|
cubecl-cuda = ["cubecl", "cubecl/cuda"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn-common = { path = "../burn-common", version = "0.14.0", default-features = false}
|
burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }
|
||||||
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.14.0", optional = true }
|
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.14.0", optional = true }
|
||||||
cubecl = { workspace = true, optional = true }
|
cubecl = { workspace = true, optional = true }
|
||||||
|
|
||||||
|
@ -32,6 +39,7 @@ num-traits = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
rand_distr = { workspace = true } # use instead of statrs because it supports no_std
|
rand_distr = { workspace = true } # use instead of statrs because it supports no_std
|
||||||
bytemuck = { workspace = true }
|
bytemuck = { workspace = true }
|
||||||
|
colored = { workspace = true, optional = true }
|
||||||
|
|
||||||
# The same implementation of HashMap in std but with no_std support (only needs alloc crate)
|
# The same implementation of HashMap in std but with no_std support (only needs alloc crate)
|
||||||
hashbrown = { workspace = true } # no_std compatible
|
hashbrown = { workspace = true } # no_std compatible
|
||||||
|
|
|
@ -33,6 +33,12 @@ pub mod ops;
|
||||||
/// Tensor quantization module.
|
/// Tensor quantization module.
|
||||||
pub mod quantization;
|
pub mod quantization;
|
||||||
|
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub use report::*;
|
||||||
|
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
mod report;
|
||||||
|
|
||||||
#[cfg(feature = "experimental-named-tensor")]
|
#[cfg(feature = "experimental-named-tensor")]
|
||||||
mod named;
|
mod named;
|
||||||
#[cfg(feature = "experimental-named-tensor")]
|
#[cfg(feature = "experimental-named-tensor")]
|
||||||
|
|
|
@ -0,0 +1,110 @@
|
||||||
|
use super::{backend::Backend, Tensor};
|
||||||
|
|
||||||
|
use colored::*;
|
||||||
|
|
||||||
|
/// Checks the closeness of two tensors and prints the results.
|
||||||
|
///
|
||||||
|
/// Compares tensors by checking the absolute difference between each element.
|
||||||
|
/// Prints the percentage of elements within specified tolerances.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `output` - The output tensor.
|
||||||
|
/// * `expected` - The expected tensor.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use burn_tensor::backend::Backend;
|
||||||
|
/// use burn_tensor::{check_closeness, Tensor};
|
||||||
|
///
|
||||||
|
/// fn example<B: Backend>() {
|
||||||
|
/// let device = Default::default();
|
||||||
|
/// let tensor1 = Tensor::<B, 1>::from_floats(
|
||||||
|
/// [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],
|
||||||
|
/// &device,
|
||||||
|
/// );
|
||||||
|
/// let tensor2 = Tensor::<B, 1>::from_floats(
|
||||||
|
/// [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],
|
||||||
|
/// &device,
|
||||||
|
/// );
|
||||||
|
/// check_closeness(&tensor1, &tensor2);
|
||||||
|
///}
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Output
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// Tensor Closeness Check Results:
|
||||||
|
/// ===============================
|
||||||
|
/// Epsilon: 1e-1
|
||||||
|
/// Close elements: 10/10 (100.00%)
|
||||||
|
/// [PASS] All elements are within tolerance
|
||||||
|
///
|
||||||
|
/// Epsilon: 1e-2
|
||||||
|
/// Close elements: 10/10 (100.00%)
|
||||||
|
/// [PASS] All elements are within tolerance
|
||||||
|
///
|
||||||
|
/// Epsilon: 1e-3
|
||||||
|
/// Close elements: 9/10 (90.00%)
|
||||||
|
/// [WARN] Most elements are within tolerance
|
||||||
|
///
|
||||||
|
/// Epsilon: 1e-4
|
||||||
|
/// Close elements: 6/10 (60.00%)
|
||||||
|
/// [FAIL] Significant differences detected
|
||||||
|
///
|
||||||
|
/// Epsilon: 1e-5
|
||||||
|
/// Close elements: 5/10 (50.00%)
|
||||||
|
/// [FAIL] Significant differences detected
|
||||||
|
///
|
||||||
|
/// Epsilon: 1e-6
|
||||||
|
/// Close elements: 5/10 (50.00%)
|
||||||
|
/// [FAIL] Significant differences detected
|
||||||
|
///
|
||||||
|
/// Epsilon: 1e-7
|
||||||
|
/// Close elements: 5/10 (50.00%)
|
||||||
|
/// [FAIL] Significant differences detected
|
||||||
|
///
|
||||||
|
/// Epsilon: 1e-8
|
||||||
|
/// Close elements: 5/10 (50.00%)
|
||||||
|
/// [FAIL] Significant differences detected
|
||||||
|
///
|
||||||
|
/// Closeness check complete.
|
||||||
|
/// ```
|
||||||
|
|
||||||
|
pub fn check_closeness<B: Backend, const D: usize>(output: &Tensor<B, D>, expected: &Tensor<B, D>) {
|
||||||
|
println!("{}", "Tensor Closeness Check Results:".bold());
|
||||||
|
println!("===============================");
|
||||||
|
|
||||||
|
for epsilon in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8].iter() {
|
||||||
|
println!("{} {:.e}", "Epsilon:".bold(), epsilon);
|
||||||
|
|
||||||
|
let close = output
|
||||||
|
.clone()
|
||||||
|
.is_close(expected.clone(), Some(*epsilon), Some(*epsilon));
|
||||||
|
let data = close.clone().into_data();
|
||||||
|
let num_elements = data.num_elements();
|
||||||
|
|
||||||
|
// Count the number of elements that are close (true)
|
||||||
|
let count = data.iter::<bool>().filter(|x| *x).count();
|
||||||
|
|
||||||
|
let percentage = (count as f64 / num_elements as f64) * 100.0;
|
||||||
|
|
||||||
|
println!(
|
||||||
|
" Close elements: {}/{} ({:.2}%)",
|
||||||
|
count, num_elements, percentage
|
||||||
|
);
|
||||||
|
|
||||||
|
if percentage == 100.0 {
|
||||||
|
println!(" {} All elements are within tolerance", "[PASS]".green());
|
||||||
|
} else if percentage >= 90.0 {
|
||||||
|
println!(" {} Most elements are within tolerance", "[WARN]".yellow());
|
||||||
|
} else {
|
||||||
|
println!(" {} Significant differences detected", "[FAIL]".red());
|
||||||
|
}
|
||||||
|
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("{}", "Closeness check complete.".bold());
|
||||||
|
}
|
Loading…
Reference in New Issue