mirror of https://github.com/tracel-ai/burn.git
Add fan_out to Conv2dConfig::init (#1138)
This commit is contained in:
parent
9bd2d7b7d4
commit
76c935861c
|
@ -70,17 +70,22 @@ impl Conv2dConfig {
|
|||
self.kernel_size[1],
|
||||
];
|
||||
|
||||
let fan_in = self.channels[0] / self.groups * self.kernel_size.iter().product::<usize>();
|
||||
let k = self.kernel_size.iter().product::<usize>() / self.groups;
|
||||
let fan_in = self.channels[0] * k;
|
||||
let fan_out = self.channels[1] * k;
|
||||
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), None, device);
|
||||
.init_with(shape, Some(fan_in), Some(fan_out), device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias = Some(
|
||||
self.initializer
|
||||
.init_with([self.channels[1]], Some(fan_in), None, device),
|
||||
);
|
||||
bias = Some(self.initializer.init_with(
|
||||
[self.channels[1]],
|
||||
Some(fan_in),
|
||||
Some(fan_out),
|
||||
device,
|
||||
));
|
||||
}
|
||||
|
||||
Conv2d {
|
||||
|
@ -161,4 +166,19 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&Data::zeros(conv.weight.shape()), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn initializer_fan_out() {
|
||||
TestBackend::seed(0);
|
||||
|
||||
let init = Initializer::KaimingUniform {
|
||||
gain: 1.0 / sqrt(3.0),
|
||||
fan_out_only: true, // test that fan_out is passed to `init_with()`
|
||||
};
|
||||
let device = Default::default();
|
||||
let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
|
||||
let _ = config.init::<TestBackend>(&device);
|
||||
|
||||
assert_eq!(config.initializer, init);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue