!1927 for thor codes

Merge pull request !1927 from zongha/master
This commit is contained in:
mindspore-ci-bot 2020-06-09 17:28:15 +08:00 committed by Gitee
commit 7dfc3ef677
2 changed files with 2 additions and 2 deletions

View File

@ -151,6 +151,8 @@ class THOR(Optimizer):
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
temp_max = self.mul(temp_max, self.feature_map[i])
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
if i == 53:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)

View File

@ -172,7 +172,6 @@ class Conv2d_Thor(_Conv):
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
self.fake_G = Tensor(
np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape))
self.fake_G_inv_max = Tensor(np.zeros([1,]).astype(np.float32))
self.shape = P.Shape()
self.reshape = P.Reshape()
@ -287,7 +286,6 @@ class Conv2d_Thor(_Conv):
matrix_A_inv = self.device_shape_pad(matrix_A_inv)
matrix_A_inv = self.reshape(matrix_A_inv, self.matrix_A_device_temp_shape)
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
self.G_inv_max = self.fake_G_inv_max
self.matrix_A_inv = matrix_A_inv
self.matrix_G_inv = self.fake_G
out = self.conv2d(x, self.weight)