Add fan_out to Conv2dConfig::init (#1138)

This commit is contained in:
Guillaume Lagrange 2024-01-12 16:02:05 -05:00 committed by GitHub
parent 9bd2d7b7d4
commit 76c935861c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 6 deletions

View File

@ -70,17 +70,22 @@ impl Conv2dConfig {
self.kernel_size[1],
];
let fan_in = self.channels[0] / self.groups * self.kernel_size.iter().product::<usize>();
let k = self.kernel_size.iter().product::<usize>() / self.groups;
let fan_in = self.channels[0] * k;
let fan_out = self.channels[1] * k;
let weight = self
.initializer
.init_with(shape, Some(fan_in), None, device);
.init_with(shape, Some(fan_in), Some(fan_out), device);
let mut bias = None;
if self.bias {
bias = Some(
self.initializer
.init_with([self.channels[1]], Some(fan_in), None, device),
);
bias = Some(self.initializer.init_with(
[self.channels[1]],
Some(fan_in),
Some(fan_out),
device,
));
}
Conv2d {
@ -161,4 +166,19 @@ mod tests {
.to_data()
.assert_approx_eq(&Data::zeros(conv.weight.shape()), 3);
}
#[test]
fn initializer_fan_out() {
TestBackend::seed(0);
let init = Initializer::KaimingUniform {
gain: 1.0 / sqrt(3.0),
fan_out_only: true, // test that fan_out is passed to `init_with()`
};
let device = Default::default();
let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
}