mirror of https://github.com/tracel-ai/burn.git
add bf16 element (#295)
This commit is contained in:
parent
d4ce825725
commit
69954c14ec
|
@ -1,5 +1,5 @@
|
|||
use crate::Distribution;
|
||||
use half::f16;
|
||||
use half::{bf16, f16};
|
||||
use num_traits::ToPrimitive;
|
||||
use rand::RngCore;
|
||||
|
||||
|
@ -124,3 +124,12 @@ make_element!(
|
|||
f16::from_elem(sample)
|
||||
}
|
||||
);
|
||||
make_element!(
|
||||
ty bf16 Precision::Half,
|
||||
convert |elem: &dyn ToPrimitive| bf16::from_f32(elem.to_f32().unwrap()),
|
||||
random |distribution: Distribution<bf16>, rng: &mut R| {
|
||||
let distribution: Distribution<f32> = distribution.convert();
|
||||
let sample = distribution.sampler(rng).sample();
|
||||
bf16::from_elem(sample)
|
||||
}
|
||||
);
|
||||
|
|
Loading…
Reference in New Issue