diff --git a/burn-tensor/src/tests/ops/aggregation.rs b/burn-tensor/src/tests/ops/aggregation.rs index a260453d4..45b94a89b 100644 --- a/burn-tensor/src/tests/ops/aggregation.rs +++ b/burn-tensor/src/tests/ops/aggregation.rs @@ -9,7 +9,7 @@ mod tests { let data_actual = tensor.mean().to_data(); - assert_eq!(data_actual, Data::from([15.0 / 6.0])); + data_actual.assert_approx_eq(&Data::from([15.0 / 6.0]), 3); } #[test] @@ -45,7 +45,7 @@ mod tests { let data_actual = tensor.mean_dim(1).to_data(); - assert_eq!(data_actual, Data::from([[3.0 / 3.0], [12.0 / 3.0]])); + data_actual.assert_approx_eq(&Data::from([[3.0 / 3.0], [12.0 / 3.0]]), 3); } #[test] diff --git a/burn-tensor/src/tests/ops/div.rs b/burn-tensor/src/tests/ops/div.rs index ab48f74ed..fb3e91f07 100644 --- a/burn-tensor/src/tests/ops/div.rs +++ b/burn-tensor/src/tests/ops/div.rs @@ -14,7 +14,7 @@ mod tests { let data_actual = output.into_data(); let data_expected = Data::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); - assert_eq!(data_expected, data_actual); + data_expected.assert_approx_eq(&data_actual, 3); } #[test] diff --git a/run-checks.ps1 b/run-checks.ps1 index 253258ed6..22ea7dccf 100644 --- a/run-checks.ps1 +++ b/run-checks.ps1 @@ -1,4 +1,5 @@ -# This script runs all `burn` checks locally +# This script runs all `burn` checks locally. It may take around 15 minutes on +# the first run. # # Run `run-checks` using this command: # diff --git a/run-checks.sh b/run-checks.sh index 11941ac53..a2ae1a282 100755 --- a/run-checks.sh +++ b/run-checks.sh @@ -3,7 +3,8 @@ # Exit immediately if a command exits with a non-zero status. set -e -# This script runs all `burn` checks locally +# This script runs all `burn` checks locally. It may take around 15 minutes +# on the first run. # # Run `run-checks` using this command: #