修复thor910训练失败

Signed-off-by: 王程浩 <wangchenghao14@huawei.com>
This commit is contained in:
王程浩 2023-02-06 11:42:55 +00:00 committed by cheng-hao-wang
parent 41540faf3c
commit 1cc549d2ed
1 changed files with 5 additions and 6 deletions

View File

@@ -256,8 +256,7 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
     Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:
-    `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation
-    <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_
+    `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation`_
     The updating formulas are as follows,
@@ -973,15 +972,15 @@ class ThorAscend(Optimizer):
             matrix_g_combine_shape = self.shape(matrix_g_inv)
             if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
                 matrix_a_inv = self.reshape(matrix_a_inv,
-                                            (matrix_a_inv_shape[0] / 16, 16,
-                                             matrix_a_inv_shape[0] / 16, 16))
+                                            (matrix_a_inv_shape[0] // 16, 16,
+                                             matrix_a_inv_shape[0] // 16, 16))
                 matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
                 matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)
                 matrix_g_inv_shape = self.shape(matrix_g_inv)
                 matrix_g_inv = self.reshape(matrix_g_inv,
-                                            (matrix_g_inv_shape[0] / 16, 16,
-                                             matrix_g_inv_shape[0] / 16, 16))
+                                            (matrix_g_inv_shape[0] // 16, 16,
+                                             matrix_g_inv_shape[0] // 16, 16))
                 matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
                 matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)