diff --git a/burn-core/src/nn/dropout.rs b/burn-core/src/nn/dropout.rs index 51f590c4f..e09062459 100644 --- a/burn-core/src/nn/dropout.rs +++ b/burn-core/src/nn/dropout.rs @@ -15,7 +15,7 @@ pub struct DropoutConfig { /// This is an effective regularization technique as describe in the paper /// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580). /// -/// The input is also scaled during training to `1 / (1-p)`. +/// The input is also scaled during training to `1 / prob_keep`. #[derive(Clone, Debug)] pub struct Dropout { prob: f64, @@ -38,11 +38,11 @@ impl Dropout { return input; } - let random = input.random_like(Distribution::Bernoulli(self.prob)); - let mask = random.equal_scalar(1); - let x = input.mask_fill(mask, 0.0_f32); + let prob_keep = 1.0 - self.prob; + let random = input.random_like(Distribution::Bernoulli(prob_keep)); + let x = input * random; - x / (1.0 - self.prob) + x * (1.0 / prob_keep) } }