add burn-tch support for bf16 (#303)

This commit is contained in:
Sunny Gonnabathula 2023-04-24 10:24:36 -05:00 committed by GitHub
parent c5e31b272f
commit 02abc373d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions

View File

@ -1,11 +1,12 @@
use burn_tensor::Element; use burn_tensor::Element;
use half::f16; use half::{bf16, f16};
pub trait TchElement: Element + tch::kind::Element {} pub trait TchElement: Element + tch::kind::Element {}
impl TchElement for f64 {} impl TchElement for f64 {}
impl TchElement for f32 {} impl TchElement for f32 {}
impl TchElement for f16 {} impl TchElement for f16 {}
impl TchElement for bf16 {}
impl TchElement for i64 {} impl TchElement for i64 {}
impl TchElement for i32 {} impl TchElement for i32 {}

View File

@ -10,5 +10,5 @@ mod tensor;
#[cfg(feature = "export_tests")] #[cfg(feature = "export_tests")]
mod tests; mod tests;
pub use half::f16; pub use half::{bf16, f16};
pub use tensor::*; pub use tensor::*;