Fix LayerNorm normalization. (#2186)

Fixes #2185.
This commit is contained in:
Guillaume Charifi 2024-08-20 13:47:15 +02:00 committed by GitHub
parent c29ed43441
commit 8053001306
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 22 additions and 1 deletions

View File

@ -67,7 +67,7 @@ impl<B: Backend> LayerNorm<B> {
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> { pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let (var, mean) = input.clone().var_mean_bias(D - 1); let (var, mean) = input.clone().var_mean_bias(D - 1);
let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon)); let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());
input_normalized input_normalized
.mul(self.gamma.val().unsqueeze()) .mul(self.gamma.val().unsqueeze())
@ -122,6 +122,27 @@ mod tests {
output.to_data().assert_approx_eq(&expected, 3); output.to_data().assert_approx_eq(&expected, 3);
} }
#[test]
fn layer_norm_forward_large_epsilon() {
let device = Default::default();
let module = LayerNormConfig::new(10)
.with_epsilon(1e-1)
.init::<TestBackend>(&device);
let input = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[
-0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
]]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([[
-0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
]]);
output.to_data().assert_approx_eq(&expected, 3);
}
#[cfg(feature = "std")] #[cfg(feature = "std")]
#[test] #[test]
fn layer_norm_backward() { fn layer_norm_backward() {