diff --git a/burn-core/src/nn/initializer.rs b/burn-core/src/nn/initializer.rs
index ef20c289a..98d4f1125 100644
--- a/burn-core/src/nn/initializer.rs
+++ b/burn-core/src/nn/initializer.rs
@@ -24,9 +24,13 @@ pub enum Initializer {
     KaimingUniform { gain: f64, fan_out_only: bool },
     /// Fills tensor with values according to the uniform version of Kaiming initialization
     KaimingNormal { gain: f64, fan_out_only: bool },
-    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization described in [Understanding the difficulty of training deep feedforward neural networks](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
+    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
+    /// described in [Understanding the difficulty of training deep feedforward neural networks
+    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
     XavierUniform { gain: f64 },
-    /// Fills tensor with values according to the normal version of Xavier Glorot initialization described in [Understanding the difficulty of training deep feedforward neural networks](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
+    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
+    /// described in [Understanding the difficulty of training deep feedforward neural networks
+    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
     XavierNormal { gain: f64 },
 }
 
@@ -56,7 +60,9 @@ impl Initializer {
     ) -> Tensor<B, D> {
         let shape = shape.into();
         match self {
-            Initializer::Constant { value } => Tensor::<B, D>::zeros(shape) + *value, // TODO replace with fill()
+            // TODO replace with full() method when implemented
+            // https://github.com/burn-rs/burn/issues/413 is the issue tracking this
+            Initializer::Constant { value } => Tensor::<B, D>::zeros(shape) + *value,
             Initializer::Ones => Tensor::<B, D>::ones(shape),
             Initializer::Zeros => Tensor::<B, D>::zeros(shape),
             Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max),
@@ -88,7 +94,7 @@ impl Initializer {
     ) -> f64 {
         let fan = if fan_out_only { fan_out } else { fan_in };
         let fan = fan.expect(
-            "Can't use Kaiming initialization without specifying fan. Use init_with method. ",
+            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
         );
         1.0 / sqrt(fan as f64)
     }
@@ -96,10 +102,10 @@ impl Initializer {
     fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
         let fan_in = fan_in.expect(
-            "Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in. ",
+            "Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in.",
         );
         let fan_out = fan_out.expect(
-            "Can't use Xavier initialization without specifying fan out. Use init_with method and provide fan_out. ",
+            "Can't use Xavier initialization without specifying fan out. Use init_with method and provide fan_out.",
         );
         sqrt(2.0 / (fan_in + fan_out) as f64)
     }
 
diff --git a/burn-core/src/nn/rnn/gru.rs b/burn-core/src/nn/rnn/gru.rs
index 8f1c558d3..cebebca14 100644
--- a/burn-core/src/nn/rnn/gru.rs
+++ b/burn-core/src/nn/rnn/gru.rs
@@ -20,7 +20,6 @@ pub struct GruConfig {
     /// If a bias should be applied during the Gru transformation.
     pub bias: bool,
     /// Gru initializer
-    /// TODO: Make default Xavier initialization. https://github.com/burn-rs/burn/issues/371
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
     pub initializer: Initializer,
     /// The batch size.
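
Reviewer note (not part of the patch): a minimal standalone sketch of the two standard-deviation formulas computed by the kaiming_std and xavier_std helpers touched above. The free functions and the 256 -> 128 layer sizes are hypothetical, for illustration only; the real helpers take Option<usize> fans and panic via expect() when a fan is missing, and the resulting std is further scaled by the configured gain (e.g. XavierNormal { gain: 1.0 }, the new GRU default).

// Sketch of the initialization std formulas, under the assumptions above.

/// Kaiming: std = 1 / sqrt(fan), where `fan` is fan_in by default,
/// or fan_out when `fan_out_only` is set.
fn kaiming_std(fan: usize) -> f64 {
    1.0 / (fan as f64).sqrt()
}

/// Xavier/Glorot: std = sqrt(2 / (fan_in + fan_out)).
fn xavier_std(fan_in: usize, fan_out: usize) -> f64 {
    (2.0 / (fan_in + fan_out) as f64).sqrt()
}

fn main() {
    // Hypothetical 256 -> 128 linear layer.
    println!("kaiming std = {}", kaiming_std(256)); // 0.0625
    println!("xavier  std = {}", xavier_std(256, 128)); // ~0.0722
}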