mirror of https://github.com/tracel-ai/burn.git
add burn-tch support for bf16 (#303)
This commit is contained in:
parent
c5e31b272f
commit
02abc373d3
|
@ -1,11 +1,12 @@
|
||||||
use burn_tensor::Element;
|
use burn_tensor::Element;
|
||||||
use half::f16;
|
use half::{bf16, f16};
|
||||||
|
|
||||||
pub trait TchElement: Element + tch::kind::Element {}
|
pub trait TchElement: Element + tch::kind::Element {}
|
||||||
|
|
||||||
impl TchElement for f64 {}
|
impl TchElement for f64 {}
|
||||||
impl TchElement for f32 {}
|
impl TchElement for f32 {}
|
||||||
impl TchElement for f16 {}
|
impl TchElement for f16 {}
|
||||||
|
impl TchElement for bf16 {}
|
||||||
|
|
||||||
impl TchElement for i64 {}
|
impl TchElement for i64 {}
|
||||||
impl TchElement for i32 {}
|
impl TchElement for i32 {}
|
||||||
|
|
|
@ -10,5 +10,5 @@ mod tensor;
|
||||||
#[cfg(feature = "export_tests")]
|
#[cfg(feature = "export_tests")]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
pub use half::f16;
|
pub use half::{bf16, f16};
|
||||||
pub use tensor::*;
|
pub use tensor::*;
|
||||||
|
|
Loading…
Reference in New Issue