mirror of https://github.com/tracel-ai/burn.git
add burn-tch support for bf16 (#303)
This commit is contained in:
parent
c5e31b272f
commit
02abc373d3
|
@ -1,11 +1,12 @@
|
|||
use burn_tensor::Element;
|
||||
use half::f16;
|
||||
use half::{bf16, f16};
|
||||
|
||||
pub trait TchElement: Element + tch::kind::Element {}
|
||||
|
||||
impl TchElement for f64 {}
|
||||
impl TchElement for f32 {}
|
||||
impl TchElement for f16 {}
|
||||
impl TchElement for bf16 {}
|
||||
|
||||
impl TchElement for i64 {}
|
||||
impl TchElement for i32 {}
|
||||
|
|
|
@ -10,5 +10,5 @@ mod tensor;
|
|||
#[cfg(feature = "export_tests")]
|
||||
mod tests;
|
||||
|
||||
pub use half::f16;
|
||||
pub use half::{bf16, f16};
|
||||
pub use tensor::*;
|
||||
|
|
Loading…
Reference in New Issue