!1482 add globalbn for vm

Merge pull request !1482 from JichenZhao/layernorm_mean_var_shape
This commit is contained in:
mindspore-ci-bot 2020-05-28 10:34:32 +08:00 committed by Gitee
commit a95c544499
2 changed files with 17 additions and 13 deletions

View File

@ -156,19 +156,23 @@ class _BatchNorm(Cell):
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
if self.is_global:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
else:
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
else:
y = self.bn_train(x,
self.gamma,

View File

@ -184,7 +184,7 @@ class AvgPool2d(_PoolNd):
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> pool = nn.AvgPool2d(kernel_size=3, strides=1)
>>> pool = nn.AvgPool2d(kernel_size=3, stride=1)
>>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
[[[[5. 5. 9. 9.]
[8. 4. 3. 0.]