fix timeout 720s for resnet50_thor st

This commit is contained in:
mwang 2021-05-14 17:19:49 +08:00
parent 14e565e922
commit 88d191e0fe
4 changed files with 516 additions and 5 deletions

View File

@ -357,6 +357,7 @@ class Model:
dataset=valid_dataset,
dataset_sink_mode=True)
self._eval_network = eval_network
self._eval_network.add_flags_recursive(thor=False)
for inputs in valid_dataset_helper:
self._eval_network.compile(*inputs)
break
@ -557,7 +558,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None.
callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None.
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with
dataset not sink.
@ -668,7 +669,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
callbacks (list): List of callback object. Callbacks which should be excuted
callbacks (list): List of callback object. Callbacks which should be executed
while training. Default: None.
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.

View File

@ -14,12 +14,13 @@
# ============================================================================
"""ResNet."""
import numpy as np
from scipy.stats import truncnorm
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from scipy.stats import truncnorm
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
fan_in = in_channel * kernel_size * kernel_size

View File

@ -0,0 +1,510 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""THOR"""
import numpy as np
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import Dense_Thor, Conv2d_Thor, Embedding_Thor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUntils
from mindspore.parallel._auto_parallel_context import auto_parallel_context
# Enumerates types of Layer
Other = -1
Conv = 1
FC = 2
Embedding = 3
LayerNorm = 4
BatchNorm = 5
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
return op_add((weight * weight_decay, gradient))
return gradient
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
return success
C0 = 16
def caculate_device_shape(matrix_dim, channel, is_A):
ll = (0)
if is_A:
if channel // C0 == 0:
matrix_dim = (matrix_dim / channel) * C0
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
else:
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
return ll
def caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim):
"""get matmul shape"""
split_dimA = split_dim
split_dimG = split_dim
if matrix_A_dim % split_dim == 0:
batch_w = matrix_A_dim // split_dim
else:
if matrix_A_dim < split_dim:
batch_w = 1
split_dimA = matrix_A_dim
else:
batch_w = matrix_A_dim // split_dim + 1
if matrix_G_dim % split_dim == 0:
batch_h = matrix_G_dim // split_dim
else:
if matrix_G_dim < split_dim:
batch_h = 1
split_dimG = matrix_G_dim
else:
batch_h = matrix_G_dim // split_dim + 1
matrix_A_shape = (batch_h, batch_w, split_dimA, split_dimA)
matrix_G_shape = (batch_h, split_dimG, split_dimG)
return matrix_A_shape, matrix_G_shape
def find_net_layertype_recur(net, layertype_map):
"""get net layer type recursively."""
cells = net.name_cells()
for name in cells:
subcell = cells[name]
if subcell == net:
continue
elif isinstance(subcell, Conv2d_Thor):
layertype_map.append(Conv)
elif isinstance(subcell, Dense_Thor):
layertype_map.append(FC)
elif isinstance(subcell, Embedding_Thor):
layertype_map.append(Embedding)
elif isinstance(subcell, nn.LayerNorm):
layertype_map.append(LayerNorm)
elif isinstance(subcell, nn.BatchNorm2d):
layertype_map.append(BatchNorm)
elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
layertype_map.append(Other)
else:
find_net_layertype_recur(subcell, layertype_map)
def get_net_layertype_mask(net):
layertype_map = []
find_net_layertype_recur(net, layertype_map)
return layertype_map
def get_layer_counter(layer_type, layer_counter, params, idx):
"""get layer counter"""
if layer_type in [Conv, FC, LayerNorm, BatchNorm]:
if layer_type in [LayerNorm, BatchNorm]:
if "beta" in params[idx].name.lower():
layer_counter = layer_counter + 1
else:
if "bias" in params[idx].name.lower():
layer_counter = layer_counter + 1
else:
if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
layer_counter = layer_counter + 1
else:
layer_counter = layer_counter + 1
return layer_counter
def THOR(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None):
context.set_context(max_call_depth=10000)
ConvertNetUntils().convert_to_thor_net(net)
return THOR_Ascend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
split_indices=split_indices)
class THOR_Ascend(Optimizer):
"""THOR"""
def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
decay_filter=lambda x: x.name not in [], split_indices=None):
params = filter(lambda x: x.requires_grad, net.get_parameters())
super(THOR_Ascend, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.net = net
self.matrix_A_cov = ParameterTuple(filter(lambda x: 'matrix_A' in x.name, net.get_parameters()))
self.matrix_G_cov = ParameterTuple(filter(lambda x: 'matrix_G' in x.name, net.get_parameters()))
self.A_normalizer = ParameterTuple(filter(lambda x: 'A_normalizer' in x.name, net.get_parameters()))
self.G_normalizer = ParameterTuple(filter(lambda x: 'G_normalizer' in x.name, net.get_parameters()))
self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()
self.C0 = 16
self.matrix_A_dim = ()
self.padA_flag = ()
self.device_shape_pad_flag = ()
self.diag_block_dim = 128
self.matrix_A = ()
self.matrix_G = ()
print("matrix_A_cov len is", len(self.matrix_A_cov))
self.thor_layer_count = 0
self.conv_layer_count = 0
self.weight_fim_idx_map = ()
self.weight_conv_idx_map = ()
self.weight_layerType_idx_map = ()
self._process_matrix_init_and_weight_idx_map(self.net)
self.matrix_A = ParameterTuple(self.matrix_A)
self.matrix_G = ParameterTuple(self.matrix_G)
self.matrix_max_inv = ()
for i in range(len(self.matrix_A)):
self.matrix_max_inv = self.matrix_max_inv + (
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.assign = P.Assign()
self.cast = P.Cast()
self.thor = True
self.weight_decay = weight_decay * loss_scale
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.damping = damping
self.gather = P.GatherV2()
self.one = Tensor(1, mstype.int32)
self.batch_size = Tensor(batch_size, mstype.float32)
self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
self.axis = 0
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
self.cast = P.Cast()
self.eye = P.Eye()
self.cholesky = P.CusCholeskyTrsm()
self.vector_matmul = P.CusBatchMatMul()
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.matrix_combine = P.CusMatrixCombine()
self.slice = P.Slice()
self.expand = P.ExpandDims()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.square = P.Square()
self.inv = P.Inv()
self.matmul = P.MatMul()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
if self.is_distributed:
mean = _get_gradients_mean()
degree = _get_device_num()
self.split_indices = split_indices
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
self.grad_reducer_Amax = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=2)
self.grad_reducer_Gmax = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=4)
self.grad_reducer_A = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=6)
self.grad_reducer_G = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=8)
def _process_matrix_init_and_weight_idx_map(self, net):
"""process matrix init shape, and get weight idx map"""
layer_type_map = get_net_layertype_mask(net)
layer_counter = 0
for idx in range(len(self.params)):
layer_type = layer_type_map[layer_counter]
weight = self.params[idx]
weight_shape = self.shape(weight)
if layer_type == Conv and "bias" not in self.params[idx].name.lower():
in_channels = weight_shape[1]
out_channels = weight_shape[0]
matrix_A_dim = in_channels * weight_shape[2] * weight_shape[3]
matrix_G_dim = out_channels
matrix_A_device_shape, matrix_A_device_dim = caculate_device_shape(matrix_A_dim, in_channels, True)
matrix_G_device_shape, matrix_G_device_dim = caculate_device_shape(matrix_G_dim, in_channels, False)
matrix_A_inv = Parameter(
Tensor(np.reshape(np.identity(matrix_A_device_dim).astype(np.float16), matrix_A_device_shape)),
name='matrix_A_inv_' + str(self.thor_layer_count), requires_grad=False)
matrix_G_inv = Parameter(
Tensor(np.reshape(np.identity(matrix_G_device_dim).astype(np.float16), matrix_G_device_shape)),
name="matrix_G_inv_" + str(self.thor_layer_count), requires_grad=False)
self.matrix_A = self.matrix_A + (matrix_A_inv,)
self.matrix_G = self.matrix_G + (matrix_G_inv,)
self.matrix_A_dim = self.matrix_A_dim + (matrix_A_dim,)
padA_flag = False
if (matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != matrix_A_dim \
and matrix_A_dim > self.diag_block_dim:
padA_flag = True
self.padA_flag = self.padA_flag + (padA_flag,)
device_shape_pad_flag = False
if matrix_A_dim != matrix_A_device_dim:
device_shape_pad_flag = True
self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,)
elif layer_type == FC and "bias" not in self.params[idx].name.lower():
out_channels = weight_shape[0]
if out_channels == 1001:
fc_matrix_A = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)),
name='matrix_A_inv_' + str(self.thor_layer_count),
requires_grad=False)
fc_matrix_G = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)),
name="matrix_G_inv_" + str(self.thor_layer_count),
requires_grad=False)
self.matrix_A = self.matrix_A + (fc_matrix_A,)
self.matrix_G = self.matrix_G + (fc_matrix_G,)
if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
self.weight_layerType_idx_map = self.weight_layerType_idx_map + (layer_type,)
self.thor_layer_count = self.thor_layer_count + 1
if layer_type == Conv:
self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
self.conv_layer_count = self.conv_layer_count + 1
else:
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
else:
self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
self.weight_layerType_idx_map = self.weight_layerType_idx_map + (Other,)
# bert.cls1.output_bias: not a network layer, only a trainable param
if "output_bias" not in self.params[idx].name.lower():
layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)
def _get_Ainv_Ginv_Amax_Gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
matrix_a_max_allreduce, matrix_g_max_allreduce):
"""get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list"""
for i in range(0, 160, 3):
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
if layer_type in [Conv, FC, Embedding]:
g = gradients[i]
matrix_A = self.matrix_A_cov[thor_layer_count]
matrix_G = self.matrix_G_cov[thor_layer_count]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
A_shape = self.shape(matrix_A)
A_eye = self.eye(A_shape[0], A_shape[0], mstype.float32)
G_shape = self.shape(matrix_G)
G_eye = self.eye(G_shape[0], G_shape[0], mstype.float32)
if layer_type == Conv:
A_normalizer = self.A_normalizer[conv_layer_count]
G_normalizer = self.G_normalizer[conv_layer_count]
A_normalizer = F.depend(A_normalizer, g)
G_normalizer = F.depend(G_normalizer, g)
dampingA = self.mul(damping_step, self.batch_size / A_normalizer)
dampingG = self.mul(damping_step, self.batch_size / G_normalizer)
dampingA = self.sqrt(dampingA)
matrix_A = matrix_A + dampingA * A_eye
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
A_max = P.CusFusedAbsMax1([self.matrix_A_dim[conv_layer_count],
self.matrix_A_dim[conv_layer_count]])(matrix_A_inv)
A_max = self.fused_abs_max2(A_max)
matrix_A_inv = self.matrix_combine(matrix_A_inv)
if self.padA_flag[conv_layer_count]:
matrix_A_inv = self.slice(matrix_A_inv, (0, 0), (self.matrix_A_dim[conv_layer_count],
self.matrix_A_dim[conv_layer_count]))
if self.device_shape_pad_flag[conv_layer_count]:
weight = self.params[i]
weight_shape = self.shape(weight)
kernel_hw = weight_shape[2] * weight_shape[3]
in_channels = weight_shape[1]
matrix_A_inv = self.reshape(matrix_A_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
matrix_A_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
(0, self.C0 - in_channels)))(matrix_A_inv)
matrix_A_inv_shape = self.shape(self.matrix_A[thor_layer_count])
matrix_A_device_temp_shape = (matrix_A_inv_shape[0], matrix_A_inv_shape[2],
matrix_A_inv_shape[1], matrix_A_inv_shape[3])
matrix_A_inv = self.reshape(matrix_A_inv, matrix_A_device_temp_shape)
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
dampingG = self.sqrt(dampingG)
matrix_G = self.mul(matrix_G, self.loss_scale)
matrix_G = self.mul(matrix_G, self.batch_size_scale)
matrix_G = matrix_G + dampingG * G_eye
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
G_max = self.fused_abs_max2(matrix_G_inv)
G_max = self.fused_abs_max2(G_max)
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv_shape = self.shape(self.matrix_G[thor_layer_count])
matrix_G_device_temp_shape = (matrix_G_inv_shape[0], matrix_G_inv_shape[2],
matrix_G_inv_shape[1], matrix_G_inv_shape[3])
matrix_G_inv = self.reshape(matrix_G_inv, matrix_G_device_temp_shape)
matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
A_max = F.depend(A_max, g)
G_max = F.depend(G_max, g)
matrix_a_allreduce = matrix_a_allreduce + (matrix_A_inv,)
matrix_g_allreduce = matrix_g_allreduce + (matrix_G_inv,)
matrix_a_max_allreduce = matrix_a_max_allreduce + (A_max,)
matrix_g_max_allreduce = matrix_g_max_allreduce + (G_max,)
elif layer_type == FC:
damping = self.sqrt(damping_step)
matrix_A = matrix_A + damping * A_eye
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
matrix_G = self.mul(matrix_G, self.loss_scale)
matrix_G = self.mul(matrix_G, self.batch_size_scale)
matrix_G = matrix_G + damping * G_eye
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv)
A_max = self.fused_abs_max2(matrix_A_inv_max)
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_A_inv_shape = self.shape(matrix_A_inv)
matrix_A_inv = self.reshape(matrix_A_inv,
(matrix_A_inv_shape[0] / 16, 16,
matrix_A_inv_shape[0] / 16, 16))
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
matrix_G_inv_max = P.CusFusedAbsMax1([1001, 1001])(matrix_G_inv)
G_max = self.fused_abs_max2(matrix_G_inv_max)
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1001, 1001))
matrix_G_inv = P.Pad(((0, 7), (0, 7)))(matrix_G_inv)
matrix_G_inv_shape = self.shape(matrix_G_inv)
matrix_G_inv = self.reshape(matrix_G_inv,
(matrix_G_inv_shape[0] / 16, 16,
matrix_G_inv_shape[0] / 16, 16))
matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
A_max = F.depend(A_max, g)
G_max = F.depend(G_max, g)
matrix_a_max_allreduce = matrix_a_max_allreduce + (A_max,)
matrix_g_max_allreduce = matrix_g_max_allreduce + (G_max,)
matrix_a_allreduce = matrix_a_allreduce + (matrix_A_inv,)
matrix_g_allreduce = matrix_g_allreduce + (matrix_G_inv,)
return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce
def _get_second_gradients(self, new_grads, damping_step, gradients):
"""get second gradients for thor"""
params_len = len(self.params)
for i in range(0, params_len - 1, 3):
g = gradients[i]
thor_layer_count = self.weight_fim_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
matrix_A = self.matrix_A[thor_layer_count]
matrix_G = self.matrix_G[thor_layer_count]
matrix_max = self.matrix_max_inv[thor_layer_count]
if layer_type == FC:
g = self.cube_matmul_left_fc(matrix_G, g)
g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
new_grads = new_grads + (g, gradients[i + 1])
elif layer_type == Conv:
g = self.cube_matmul_left(matrix_G, g)
g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
new_grads = new_grads + (g, gradients[i + 1], gradients[i + 2])
return new_grads
def construct(self, gradients):
params = self.params
moments = self.moments
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
if self.thor:
matrix_A_allreduce = ()
matrix_G_allreduce = ()
matrix_A_max_allreduce = ()
matrix_G_max_allreduce = ()
matrix_A_allreduce, matrix_G_allreduce, matrix_A_max_allreduce, matrix_G_max_allreduce = \
self._get_Ainv_Ginv_Amax_Gmax_list(gradients, damping_step, matrix_A_allreduce, matrix_G_allreduce,
matrix_A_max_allreduce, matrix_G_max_allreduce)
if self.is_distributed:
matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
new_grads = ()
for i in range(0, 160, 3):
g = gradients[i]
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
temp_a = matrix_A_allreduce[thor_layer_count]
temp_g = matrix_G_allreduce[thor_layer_count]
matrix_A_inv_max = self.log(matrix_A_max_allreduce[thor_layer_count])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(matrix_G_max_allreduce[thor_layer_count])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(matrix_A_max_allreduce[thor_layer_count],
matrix_G_max_allreduce[thor_layer_count])
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
if layer_type == FC:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
elif layer_type == Conv:
A_normalizer = self.A_normalizer[conv_layer_count]
A_normalizer = F.depend(A_normalizer, g)
temp_max = self.mul(temp_max, self.batch_size / A_normalizer)
g = self.cube_matmul_left(temp_g, g)
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
fake_A = self.assign(self.matrix_A[thor_layer_count], temp_a)
fake_G = self.assign(self.matrix_G[thor_layer_count], temp_g)
fake_max = self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
if i == 159:
new_grads = new_grads + (g, gradients[i + 1])
else:
new_grads = new_grads + (g, gradients[i + 1], gradients[i + 2])
gradients = new_grads
else:
new_grads = ()
gradients = self._get_second_gradients(new_grads, damping_step, gradients)
self.cov_step = self.cov_step + self.one
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success

View File

@ -30,7 +30,6 @@ from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore.nn.optim import THOR
from tests.st.networks.models.resnet50.src.resnet import resnet50
from tests.st.networks.models.resnet50.src.dataset import create_dataset
@ -42,7 +41,7 @@ from tests.st.networks.models.resnet50.src_thor.config import config as thor_con
from tests.st.networks.models.resnet50.src_thor.dataset import create_dataset as create_dataset_thor
from tests.st.networks.models.resnet50.src_thor.model_thor import Model as THOR_Model
from tests.st.networks.models.resnet50.src_thor.resnet import resnet50 as resnet50_thor
from tests.st.networks.models.resnet50.src_thor.thor import THOR
MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_1.json"
MINDSPORE_HCCL_CONFIG_PATH_2 = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_2.json"