mirror of https://github.com/tracel-ai/burn.git
add bf16 element (#295)
This commit is contained in:
parent
d4ce825725
commit
69954c14ec
|
@ -1,5 +1,5 @@
|
||||||
use crate::Distribution;
|
use crate::Distribution;
|
||||||
use half::f16;
|
use half::{bf16, f16};
|
||||||
use num_traits::ToPrimitive;
|
use num_traits::ToPrimitive;
|
||||||
use rand::RngCore;
|
use rand::RngCore;
|
||||||
|
|
||||||
|
@ -124,3 +124,12 @@ make_element!(
|
||||||
f16::from_elem(sample)
|
f16::from_elem(sample)
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
make_element!(
|
||||||
|
ty bf16 Precision::Half,
|
||||||
|
convert |elem: &dyn ToPrimitive| bf16::from_f32(elem.to_f32().unwrap()),
|
||||||
|
random |distribution: Distribution<bf16>, rng: &mut R| {
|
||||||
|
let distribution: Distribution<f32> = distribution.convert();
|
||||||
|
let sample = distribution.sampler(rng).sample();
|
||||||
|
bf16::from_elem(sample)
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
Loading…
Reference in New Issue