mirror of https://github.com/tracel-ai/burn.git
Fix where sqrt should’ve been powf(2.) (#1052)
This commit is contained in:
parent
ef09b637b9
commit
3066196297
|
@ -108,8 +108,10 @@ impl<B: Backend> GroupNorm<B> {
|
|||
let input = input.reshape([batch_size, self.num_groups, hidden_size]);
|
||||
|
||||
let mean = input.clone().sum_dim(2) / hidden_size as f64;
|
||||
let var = input.clone().sqrt().sum_dim(2) / hidden_size as f64;
|
||||
let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon));
|
||||
let input = input.sub(mean);
|
||||
|
||||
let var = input.clone().powf(2.).sum_dim(2) / hidden_size as f64;
|
||||
let input_normalized = input.div(var.sqrt().add_scalar(self.epsilon));
|
||||
|
||||
if self.affine {
|
||||
let mut affine_shape = [1; D];
|
||||
|
@ -142,7 +144,7 @@ mod tests {
|
|||
|
||||
let input = Tensor::from_data(Data::from([
|
||||
[
|
||||
[-0.3034f32, 0.2726, -0.9659],
|
||||
[-0.3034, 0.2726, -0.9659],
|
||||
[-1.1845, -1.3236, 0.0172],
|
||||
[1.9507, 1.2554, -0.8625],
|
||||
[1.0682, 0.3604, 0.3985],
|
||||
|
@ -208,20 +210,20 @@ mod tests {
|
|||
|
||||
let input = Tensor::from_data(Data::from([
|
||||
[
|
||||
[-0.3034f32, 0.2726, -0.9659],
|
||||
[-1.1845, -1.3236, 0.0172],
|
||||
[1.9507, 1.2554, -0.8625],
|
||||
[1.0682, 0.3604, 0.3985],
|
||||
[-0.4957, -0.4461, -0.9721],
|
||||
[1.5157, -0.1546, -0.5596],
|
||||
[0.3345, 0.4429, 0.6639],
|
||||
[0.5041, 0.4175, 0.8437],
|
||||
[0.6159, 0.3758, 0.4071],
|
||||
[0.5417, 0.5785, 0.7671],
|
||||
[0.3837, 0.9883, 0.0420],
|
||||
[0.4808, 0.8989, 0.6144],
|
||||
],
|
||||
[
|
||||
[-1.6698, -0.4040, -0.7927],
|
||||
[0.3736, -0.0975, -0.1351],
|
||||
[-0.9461, 0.5461, -0.6334],
|
||||
[-1.0919, -0.1158, 0.1213],
|
||||
[-0.9535, 0.1281, 0.4372],
|
||||
[-0.2845, 0.3488, 0.5641],
|
||||
[0.3930, 0.2098, 0.0602],
|
||||
[0.2298, 0.9425, 0.0333],
|
||||
[0.7409, 0.8172, 0.8879],
|
||||
[0.4846, 0.0486, 0.2029],
|
||||
[0.6741, 0.9765, 0.6864],
|
||||
[0.2827, 0.5534, 0.2125],
|
||||
],
|
||||
]));
|
||||
|
||||
|
@ -230,20 +232,20 @@ mod tests {
|
|||
output.to_data().assert_approx_eq(
|
||||
&Data::from([
|
||||
[
|
||||
[0.4560, 1.4014, -0.6313],
|
||||
[-0.9901, -1.2184, 0.9822],
|
||||
[1.4254, 0.6360, -1.7682],
|
||||
[0.4235, -0.3800, -0.3367],
|
||||
[-0.3890, -0.3268, -0.9862],
|
||||
[2.1325, 0.0386, -0.4691],
|
||||
[-1.1694, -0.5353, 0.7572],
|
||||
[-0.1775, -0.6838, 1.8087],
|
||||
[0.5205, -1.3107, -1.0723],
|
||||
[-0.0459, 0.2351, 1.6734],
|
||||
[-0.5796, 1.3218, -1.6544],
|
||||
[-0.2744, 1.0406, 0.1459],
|
||||
],
|
||||
[
|
||||
[-1.8797, 0.0777, -0.5234],
|
||||
[1.2802, 0.5517, 0.4935],
|
||||
[-1.0102, 1.5327, -0.4773],
|
||||
[-1.2587, 0.4047, 0.8088],
|
||||
[-1.9074, 0.1691, 0.7625],
|
||||
[-0.6230, 0.5928, 1.0061],
|
||||
[0.2665, -0.3320, -0.8205],
|
||||
[-0.2667, 2.0612, -0.9085],
|
||||
[0.6681, 0.9102, 1.1345],
|
||||
[-0.1453, -1.5287, -1.0389],
|
||||
[0.4253, 1.5962, 0.4731],
|
||||
[-1.0903, -0.0419, -1.3623],
|
||||
],
|
||||
]),
|
||||
3,
|
||||
|
|
Loading…
Reference in New Issue