perf: remove unnecessary masking in dropout (#172)

Nathaniel Simard 2023-02-26 11:43:46 -05:00 committed by GitHub
parent fb925acc73
commit 6ba9de2868
1 changed file with 5 additions and 5 deletions


@@ -15,7 +15,7 @@ pub struct DropoutConfig {
 /// This is an effective regularization technique as describe in the paper
 /// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).
 ///
-/// The input is also scaled during training to `1 / (1-p)`.
+/// The input is also scaled during training to `1 / (1 - prob_keep)`.
 #[derive(Clone, Debug)]
 pub struct Dropout {
     prob: f64,
@@ -38,11 +38,11 @@ impl Dropout {
             return input;
         }
-        let random = input.random_like(Distribution::Bernoulli(self.prob));
-        let mask = random.equal_scalar(1);
-        let x = input.mask_fill(mask, 0.0_f32);
+        let prob_keep = 1.0 - self.prob;
+        let random = input.random_like(Distribution::Bernoulli(prob_keep));
+        let x = input * random;
-        x / (1.0 - self.prob)
+        x * (1.0 / prob_keep)
     }
 }
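
The change replaces a three-step mask (sample, compare against 1, mask-fill) with a single sample-and-multiply: a Bernoulli(prob_keep) draw is already the 0/1 keep mask, so multiplying the input by it zeroes dropped units directly, and folding in the 1 / prob_keep scale keeps the expected activation unchanged. Below is a minimal standalone sketch of the same inverted-dropout idea, written against the rand crate rather than burn's tensor API; the function name and signature are illustrative, not part of the commit.

use rand::Rng;

// Illustrative inverted dropout: one sample-and-multiply pass, with no
// separate equality comparison or mask-fill step (mirrors the diff above).
fn dropout(input: &[f32], prob_drop: f64) -> Vec<f32> {
    if prob_drop <= 0.0 {
        return input.to_vec(); // nothing to drop
    }
    let prob_keep = 1.0 - prob_drop;
    let scale = (1.0 / prob_keep) as f32; // rescale so E[output] == input
    let mut rng = rand::thread_rng();
    input
        .iter()
        // Each Bernoulli draw is already the keep mask: multiply through,
        // scaling kept units by 1 / prob_keep and zeroing dropped ones.
        .map(|&x| if rng.gen_bool(prob_keep) { x * scale } else { 0.0 })
        .collect()
}

fn main() {
    let activations = vec![1.0_f32, 2.0, 3.0, 4.0];
    // With prob_drop = 0.5, kept entries are doubled and dropped ones are 0.
    println!("{:?}", dropout(&activations, 0.5));
}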