mirror of https://github.com/tracel-ai/burn.git
perf: remove unnecessary masking in dropout (#172)
This commit is contained in:
parent
fb925acc73
commit
6ba9de2868
|
@ -15,7 +15,7 @@ pub struct DropoutConfig {
|
|||
/// This is an effective regularization technique as described in the paper
|
||||
/// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).
|
||||
///
|
||||
/// The input is also scaled during training to `1 / (1-p)`.
|
||||
/// The input is also scaled during training to `1 / prob_keep`.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Dropout {
|
||||
prob: f64,
|
||||
|
@ -38,11 +38,11 @@ impl Dropout {
|
|||
return input;
|
||||
}
|
||||
|
||||
let random = input.random_like(Distribution::Bernoulli(self.prob));
|
||||
let mask = random.equal_scalar(1);
|
||||
let x = input.mask_fill(mask, 0.0_f32);
|
||||
let prob_keep = 1.0 - self.prob;
|
||||
let random = input.random_like(Distribution::Bernoulli(prob_keep));
|
||||
let x = input * random;
|
||||
|
||||
x / (1.0 - self.prob)
|
||||
x * (1.0 / prob_keep)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue