add bf16 element (#295)

This commit is contained in:
Sunny Gonnabathula 2023-04-12 11:36:03 -05:00 committed by GitHub
parent d4ce825725
commit 69954c14ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 1 deletions

View File

@ -1,5 +1,5 @@
use crate::Distribution; use crate::Distribution;
use half::f16; use half::{bf16, f16};
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
use rand::RngCore; use rand::RngCore;
@ -124,3 +124,12 @@ make_element!(
f16::from_elem(sample) f16::from_elem(sample)
} }
); );
make_element!(
ty bf16 Precision::Half,
convert |elem: &dyn ToPrimitive| bf16::from_f32(elem.to_f32().unwrap()),
random |distribution: Distribution<bf16>, rng: &mut R| {
let distribution: Distribution<f32> = distribution.convert();
let sample = distribution.sampler(rng).sample();
bf16::from_elem(sample)
}
);