From 76c935861cced45793165f94bece24e9e456a667 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 12 Jan 2024 16:02:05 -0500 Subject: [PATCH] Add fan_out to Conv2dConfig::init (#1138) --- burn-core/src/nn/conv/conv2d.rs | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/burn-core/src/nn/conv/conv2d.rs b/burn-core/src/nn/conv/conv2d.rs index 64906e777..7d5dc36f7 100644 --- a/burn-core/src/nn/conv/conv2d.rs +++ b/burn-core/src/nn/conv/conv2d.rs @@ -70,17 +70,22 @@ impl Conv2dConfig { self.kernel_size[1], ]; - let fan_in = self.channels[0] / self.groups * self.kernel_size.iter().product::(); + let k = self.kernel_size.iter().product::() / self.groups; + let fan_in = self.channels[0] * k; + let fan_out = self.channels[1] * k; + let weight = self .initializer - .init_with(shape, Some(fan_in), None, device); + .init_with(shape, Some(fan_in), Some(fan_out), device); let mut bias = None; if self.bias { - bias = Some( - self.initializer - .init_with([self.channels[1]], Some(fan_in), None, device), - ); + bias = Some(self.initializer.init_with( + [self.channels[1]], + Some(fan_in), + Some(fan_out), + device, + )); } Conv2d { @@ -161,4 +166,19 @@ mod tests { .to_data() .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); } + + #[test] + fn initializer_fan_out() { + TestBackend::seed(0); + + let init = Initializer::KaimingUniform { + gain: 1.0 / sqrt(3.0), + fan_out_only: true, // test that fan_out is passed to `init_with()` + }; + let device = Default::default(); + let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone()); + let _ = config.init::(&device); + + assert_eq!(config.initializer, init); + } }