mirror of https://github.com/tracel-ai/burn.git
Minor clean up of doc formatting and remove outdated TODO (#415)
This commit is contained in:
parent 4d40bde7b9
commit 4683acf726
@@ -24,9 +24,13 @@ pub enum Initializer {
     KaimingUniform { gain: f64, fan_out_only: bool },
     /// Fills tensor with values according to the uniform version of Kaiming initialization
     KaimingNormal { gain: f64, fan_out_only: bool },
-    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization described in [Understanding the difficulty of training deep feedforward neural networks](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
+    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
+    /// described in [Understanding the difficulty of training deep feedforward neural networks
+    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
     XavierUniform { gain: f64 },
-    /// Fills tensor with values according to the normal version of Xavier Glorot initialization described in [Understanding the difficulty of training deep feedforward neural networks](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
+    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
+    /// described in [Understanding the difficulty of training deep feedforward neural networks
+    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
     XavierNormal { gain: f64 },
 }
 
@@ -56,7 +60,9 @@ impl Initializer {
     ) -> Tensor<B, D> {
         let shape = shape.into();
         match self {
-            Initializer::Constant { value } => Tensor::<B, D>::zeros(shape) + *value, // TODO replace with fill()
+            // TODO replace with full() method when implemented
+            // https://github.com/burn-rs/burn/issues/413 is the issue tracking this
+            Initializer::Constant { value } => Tensor::<B, D>::zeros(shape) + *value,
             Initializer::Ones => Tensor::<B, D>::ones(shape),
             Initializer::Zeros => Tensor::<B, D>::zeros(shape),
             Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max),
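As a side note, the arms above map an Initializer variant plus a shape to a freshly initialized tensor; the Constant arm works around the missing full()/fill() method by adding the scalar to a zero tensor. Below is a minimal stand-alone sketch of that dispatch idea using a plain Vec<f64> instead of a backend tensor; the SimpleInit enum, num_elements helper, and init function are hypothetical illustrations, not Burn's API.

// Simplified stand-in for the match above: each variant turns a shape into a
// flat buffer of initial values. Hypothetical names, not part of Burn.
enum SimpleInit {
    Constant { value: f64 },
    Ones,
    Zeros,
}

fn num_elements(shape: &[usize]) -> usize {
    shape.iter().product()
}

fn init(initializer: &SimpleInit, shape: &[usize]) -> Vec<f64> {
    let n = num_elements(shape);
    match initializer {
        // Same workaround as in the diff: start from zeros and add the constant,
        // since a dedicated full()/fill() is still pending (issue #413).
        SimpleInit::Constant { value } => vec![0.0; n].into_iter().map(|x| x + value).collect(),
        SimpleInit::Ones => vec![1.0; n],
        SimpleInit::Zeros => vec![0.0; n],
    }
}

fn main() {
    let filled = init(&SimpleInit::Constant { value: 3.5 }, &[2, 3]);
    assert_eq!(filled, vec![3.5; 6]);
}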
@@ -88,7 +94,7 @@ impl Initializer {
     ) -> f64 {
         let fan = if fan_out_only { fan_out } else { fan_in };
         let fan = fan.expect(
-            "Can't use Kaiming initialization without specifying fan. Use init_with method. ",
+            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
         );
 
         1.0 / sqrt(fan as f64)
@@ -96,10 +102,10 @@ impl Initializer {
 
     fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
         let fan_in = fan_in.expect(
-            "Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in. ",
+            "Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in.",
         );
         let fan_out = fan_out.expect(
-            "Can't use Xavier initialization without specifying fan out. Use init_with method and provide fan_out. ",
+            "Can't use Xavier initialization without specifying fan out. Use init_with method and provide fan_out.",
         );
         sqrt(2.0 / (fan_in + fan_out) as f64)
     }
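The two hunks above only trim trailing whitespace from the panic messages, but they also show the standard deviations these helpers compute: Kaiming uses 1 / sqrt(fan) and Xavier uses sqrt(2 / (fan_in + fan_out)). The sketch below restates those expressions as stand-alone free functions, ignoring the gain field carried by the enum variants; the function names and example layer sizes are illustrative only, not Burn's private helpers.

// Stand-alone versions of the standard deviations computed above.
fn kaiming_std(fan: usize) -> f64 {
    1.0 / (fan as f64).sqrt()
}

fn xavier_std(fan_in: usize, fan_out: usize) -> f64 {
    (2.0 / (fan_in + fan_out) as f64).sqrt()
}

fn main() {
    // e.g. a 128 -> 64 linear layer (illustrative numbers only)
    println!("kaiming std = {:.4}", kaiming_std(128));    // ~0.0884
    println!("xavier std  = {:.4}", xavier_std(128, 64)); // ~0.1021
}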
@@ -20,7 +20,6 @@ pub struct GruConfig {
     /// If a bias should be applied during the Gru transformation.
     pub bias: bool,
     /// Gru initializer
-    /// TODO: Make default Xavier initialization. https://github.com/burn-rs/burn/issues/371
     #[config(default = "Initializer::XavierNormal{gain:1.0}")]
     pub initializer: Initializer,
     /// The batch size.
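The removed TODO was already satisfied: the #[config(default = ...)] attribute pins the Gru initializer to Xavier normal with gain 1.0. The snippet below is a hedged, self-contained sketch of the value that default string denotes; the enum is re-declared locally (with only the relevant variant) so the snippet compiles on its own, whereas in Burn the value refers to the Initializer enum from the first hunk.

// Mirror of the config default "Initializer::XavierNormal{gain:1.0}".
// Local re-declaration for illustration only, not the real module.
#[derive(Debug)]
enum Initializer {
    XavierNormal { gain: f64 },
}

fn main() {
    let default_init = Initializer::XavierNormal { gain: 1.0 };
    println!("GruConfig default initializer: {:?}", default_init);
}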