Minor cleanup of doc formatting and remove outdated TODO (#415)

Dilshod Tadjibaev 2023-06-20 09:05:28 -05:00 committed by GitHub
parent 4d40bde7b9
commit 4683acf726
2 changed files with 12 additions and 7 deletions

View File

@@ -24,9 +24,13 @@ pub enum Initializer {
     KaimingUniform { gain: f64, fan_out_only: bool },
     /// Fills tensor with values according to the uniform version of Kaiming initialization
     KaimingNormal { gain: f64, fan_out_only: bool },
-    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization described in [Understanding the difficulty of training deep feedforward neural networks](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
+    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
+    /// described in [Understanding the difficulty of training deep feedforward neural networks
+    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
     XavierUniform { gain: f64 },
-    /// Fills tensor with values according to the normal version of Xavier Glorot initialization described in [Understanding the difficulty of training deep feedforward neural networks](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
+    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
+    /// described in [Understanding the difficulty of training deep feedforward neural networks
+    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
     XavierNormal { gain: f64 },
 }
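As an aside, the two Xavier variants documented above differ only in the distribution they draw from; both use the scale from the linked Glorot paper. A standalone sketch of those scale factors in plain Rust (illustrative helpers, not part of burn's API):

    // Xavier/Glorot scale factors from the paper linked above.
    // The uniform variant draws from U(-a, a); the normal variant uses N(0, std^2).
    fn xavier_uniform_bound(gain: f64, fan_in: usize, fan_out: usize) -> f64 {
        gain * (6.0 / (fan_in + fan_out) as f64).sqrt()
    }

    fn xavier_normal_std(gain: f64, fan_in: usize, fan_out: usize) -> f64 {
        gain * (2.0 / (fan_in + fan_out) as f64).sqrt()
    }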
@@ -56,7 +60,9 @@ impl Initializer {
     ) -> Tensor<B, D> {
         let shape = shape.into();
         match self {
-            Initializer::Constant { value } => Tensor::<B, D>::zeros(shape) + *value, // TODO replace with fill()
+            // TODO replace with full() method when implemented
+            // https://github.com/burn-rs/burn/issues/413 is the issue tracking this
+            Initializer::Constant { value } => Tensor::<B, D>::zeros(shape) + *value,
             Initializer::Ones => Tensor::<B, D>::ones(shape),
             Initializer::Zeros => Tensor::<B, D>::zeros(shape),
             Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max),
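The relocated TODO above concerns building a constant tensor as zeros plus a scalar until a full()-style constructor exists (tracked in burn-rs/burn#413). A minimal stand-in with plain Vecs rather than burn's Tensor shows the equivalence the workaround relies on:

    // `zeros` followed by a scalar add yields the same buffer a `full`
    // constructor would produce in one pass; both helpers are hypothetical.
    fn zeros(len: usize) -> Vec<f64> { vec![0.0; len] }
    fn full(len: usize, value: f64) -> Vec<f64> { vec![value; len] }

    fn main() {
        let via_add: Vec<f64> = zeros(4).into_iter().map(|x| x + 3.5).collect();
        assert_eq!(via_add, full(4, 3.5));
    }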
@@ -88,7 +94,7 @@ impl Initializer {
     ) -> f64 {
         let fan = if fan_out_only { fan_out } else { fan_in };
         let fan = fan.expect(
-            "Can't use Kaiming initialization without specifying fan. Use init_with method. ",
+            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
         );
         1.0 / sqrt(fan as f64)
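The scale computed here is the Kaiming one: select fan_in, or fan_out when fan_out_only is set, then return 1 / sqrt(fan). A self-contained sketch mirroring that logic outside the diff (the snippet's sqrt helper is replaced with the standard library method):

    // Mirrors the fan selection and 1/sqrt(fan) scale from the hunk above.
    fn kaiming_scale(fan_in: Option<usize>, fan_out: Option<usize>, fan_out_only: bool) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect("Can't use Kaiming initialization without specifying fan.");
        1.0 / (fan as f64).sqrt()
    }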
@@ -96,10 +102,10 @@ impl Initializer {
     fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
         let fan_in = fan_in.expect(
-            "Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in. ",
+            "Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in.",
         );
         let fan_out = fan_out.expect(
-            "Can't use Xavier initialization without specifying fan out. Use init_with method and provide fan_out. ",
+            "Can't use Xavier initialization without specifying fan out. Use init_with method and provide fan_out.",
         );
         sqrt(2.0 / (fan_in + fan_out) as f64)
     }
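For a concrete number, the expression above gives sqrt(2 / (256 + 128)) ≈ 0.0722 for a 256-in, 128-out layer. A standalone check (again using the standard library sqrt in place of the snippet's helper):

    // Quick numeric check of xavier_std's formula.
    fn xavier_std(fan_in: usize, fan_out: usize) -> f64 {
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }

    fn main() {
        println!("{:.4}", xavier_std(256, 128)); // prints 0.0722
    }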

View File

@@ -20,7 +20,6 @@ pub struct GruConfig {
     /// If a bias should be applied during the Gru transformation.
     pub bias: bool,
     /// Gru initializer
-    /// TODO: Make default Xavier initialization. https://github.com/burn-rs/burn/issues/371
     #[config(default = "Initializer::XavierNormal{gain:1.0}")]
     pub initializer: Initializer,
     /// The batch size.
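The #[config(default = ...)] attribute above makes XavierNormal { gain: 1.0 } the fallback when no initializer is supplied, which is what the removed TODO asked for (burn-rs/burn#371). A pared-down sketch of that same default, using a stand-in enum rather than burn's Config derive:

    // Stand-in for the default the attribute encodes; not burn's real types.
    #[derive(Debug, PartialEq)]
    enum Initializer {
        XavierNormal { gain: f64 },
    }

    impl Default for Initializer {
        fn default() -> Self {
            Initializer::XavierNormal { gain: 1.0 }
        }
    }

    fn main() {
        assert_eq!(Initializer::default(), Initializer::XavierNormal { gain: 1.0 });
    }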