This commit is contained in:
Nathaniel Simard 2024-02-01 14:50:38 -05:00 committed by syl20bnr
parent 23d653ac0b
commit d646417614
5 changed files with 43 additions and 17 deletions

View File

@ -11,7 +11,7 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-autodiff"
version.workspace = true
[features]
default = ["export_tests"]
default = []
export_tests = ["burn-tensor-testgen"]
[dependencies]
@ -21,3 +21,10 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.12.1", opt
derive-new = { workspace = true }
spin = { workspace = true }
[dev-dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.12.1", default-features = false, features = [
"export_tests",
] }

View File

@ -40,6 +40,16 @@ std = [
]
doc = [
"std",
# Backends
"dataset",
"candle",
"fusion",
"ndarray",
"tch",
"wgpu",
"vision",
"autodiff",
# Doc features
"burn-candle/doc",
"burn-common/doc",
"burn-dataset/doc",

View File

@ -48,13 +48,9 @@ blas-openblas-system = [
# ** Please make sure all dependencies support no_std when std is disabled **
burn-autodiff = { path = "../burn-autodiff", version = "0.12.1", features = [
"export_tests",
], optional = true }
burn-autodiff = { path = "../burn-autodiff", version = "0.12.1", optional = true }
burn-common = { path = "../burn-common", version = "0.12.1", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.12.1", default-features = false, features = [
"export_tests",
] }
burn-tensor = { path = "../burn-tensor", version = "0.12.1", default-features = false }
matrixmultiply = { workspace = true, default-features = false }
rayon = { workspace = true, optional = true }
@ -67,5 +63,13 @@ openblas-src = { workspace = true, optional = true }
rand = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.12.1", default-features = false, features = [
"export_tests",
] }
burn-tensor = { path = "../burn-tensor", version = "0.12.1", default-features = false, features = [
"export_tests",
] }
[package.metadata.docs.rs]
features = ["doc"]

View File

@ -16,6 +16,7 @@ mod tensor;
mod tests;
pub use half::{bf16, f16};
pub(crate) use tensor::check::macros::check;
pub use tensor::*;
pub use burn_common::reader::Reader; // Useful so that backends don't have to add `burn_common` as

View File

@ -818,21 +818,25 @@ impl TensorError {
}
}
/// We use a macro for all checks, since the panic message file and line number will match the
/// function that does the check instead of a the generic error.rs crate private unrelated file
/// and line number.
#[macro_export(local_inner_macros)]
macro_rules! check {
($check:expr) => {
if let TensorCheck::Failed(check) = $check {
core::panic!("{}", check.format());
}
};
/// Module where we defined macros that can be used only in the project.
pub(crate) mod macros {
/// We use a macro for all checks, since the panic message file and line number will match the
/// function that does the check instead of a the generic error.rs crate private unrelated file
/// and line number.
macro_rules! check {
($check:expr) => {
if let TensorCheck::Failed(check) = $check {
core::panic!("{}", check.format());
}
};
}
pub(crate) use check;
}
#[cfg(test)]
mod tests {
use super::*;
use macros::check;
#[test]
#[should_panic]