!29708 update dim_reduce: move scale_loss from optimizer to outer, add param filter

Merge pull request !29708 from jinjiali-kali/r1.6
This commit is contained in:
i-robot 2022-02-07 11:27:22 +00:00 committed by Gitee
commit aecc3ffe58
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 48 additions and 26 deletions

View File

@@ -276,7 +276,7 @@ class ParameterProcess:
         return group_params
-def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_number):
+def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_number, network):
     """
     get local pca mat path.
@@ -285,6 +285,7 @@ def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_n
         pca_mat_path (str): the path to load pca mat. Default: None.
         n_component (int): pca component.
         device_number (int): device number.
+        network (Cell): The network.
     """
     if pca_mat_path is not None and os.path.exists(pca_mat_path) and os.path.isfile(pca_mat_path) and \
             pca_mat_path.endswith(".npy"):
@@ -312,20 +313,25 @@ def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_n
     if pca_mat_exist:
         pca_mat = np.load(full_pca_mat_path)
     else:
-        data = _load_weights(weight_load_dir)
+        data = _load_weights(weight_load_dir, network)
         pca_mat = _compute_pca_mat(data, n_component)
-        np.save(full_pca_mat_path, pca_mat)
+        _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component)
     return local_pca_mat_path

-def _load_weights(weight_load_dir):
+def _load_weights(weight_load_dir, network):
     """
     load weights.
     Args:
         weight_load_dir (str): The weight(ckpt) file directory to be load.
+        network (Cell): The network.
     """
+    param_requires_grad_dict = {}
+    for param in network.trainable_params():
+        param_requires_grad_dict[param.name] = param.requires_grad
     param_mat_tuple = ()
     weight_file_list = os.listdir(weight_load_dir)
     for file in weight_file_list:
@@ -334,8 +340,9 @@ def _load_weights(weight_load_dir):
         file_path = os.path.join(weight_load_dir, file)
         param_dict = load_checkpoint(file_path)
         param_tuple = ()
-        for _, value in param_dict.items():
-            param_tuple += (value.asnumpy().reshape((1, -1)),)
+        for key, value in param_dict.items():
+            if param_requires_grad_dict[key]:
+                param_tuple += (value.asnumpy().reshape((1, -1)),)
         param = np.concatenate(param_tuple, axis=1)
         param_mat_tuple += (param,)
     param_mat = np.concatenate(param_mat_tuple, axis=0)
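
For reference, a framework-free sketch of the parameter filter that `_load_weights` now applies, using toy NumPy data. The function name `stack_trainable_weights` and the sample parameter names are illustrative, not part of the MindSpore API:

```python
import numpy as np

def stack_trainable_weights(checkpoints, requires_grad):
    """checkpoints: list of {param_name: ndarray}; requires_grad: {param_name: bool}."""
    rows = []
    for ckpt in checkpoints:
        # keep only parameters the network actually trains, flatten them into one row
        kept = [value.reshape(1, -1) for name, value in ckpt.items()
                if requires_grad.get(name, False)]
        rows.append(np.concatenate(kept, axis=1))
    # one row per checkpoint: this is the matrix later handed to PCA
    return np.concatenate(rows, axis=0)

# toy usage: BatchNorm statistics are filtered out, so PCA only sees trainable weights
ckpts = [{"conv.weight": np.ones((2, 2)), "bn.moving_mean": np.zeros(2)} for _ in range(3)]
flags = {"conv.weight": True, "bn.moving_mean": False}
print(stack_trainable_weights(ckpts, flags).shape)  # (3, 4)
```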

View File

@@ -245,8 +245,8 @@ class AutoBoost:
         """
         if self.boost_config["dim_reduce"]:
             self.local_pca_mat_path = _get_local_pca_mat_path(self.weight_load_dir, self.pca_mat_path,
-                                                               self.n_components, self.device_number)
-            optimizer = SGD(network.trainable_params(), learning_rate=1, loss_scale=optimizer.loss_scale)
+                                                               self.n_components, self.device_number, network)
+            optimizer = SGD(network.trainable_params(), learning_rate=1)
             setattr(optimizer, "dim_reduce", True)
         return network, optimizer
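
A toy, framework-free illustration of why `loss_scale` is dropped from the SGD constructor here: once DimReduce receives the scale directly, the un-scaling happens where the raw gradients are consumed, and SGD with `learning_rate=1` simply applies the update it is handed. The numbers below are made up:

```python
scale = 1024.0
raw_grad = 0.5 * scale                 # gradient of the scaled loss

# before: SGD(..., loss_scale=scale) divided the gradient internally
step_before = -(raw_grad / scale)

# after: DimReduce divides by the loss scale it now receives (see _scale_grad),
# and SGD(learning_rate=1) just applies whatever update DimReduce produced
unscaled_grad = raw_grad / scale
step_after = -unscaled_grad

assert step_before == step_after       # the optimizer change is behavior-preserving
```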

View File

@@ -223,10 +223,10 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         grads = self.grad(self.network, self.weights)(*inputs, sens)
         grads = self.grad_reducer(grads)
         if self.use_grad_accumulation:
-            loss = self.gradient_accumulation_process(loss, grads, *inputs)
+            loss = self.gradient_accumulation_process(loss, grads, sens, *inputs)
         else:
             if self.enable_dim_reduce:
-                loss = F.depend(loss, self.dim_reduce(loss, grads, self.weights, self.weights_clone, *inputs))
+                loss = F.depend(loss, self.dim_reduce(loss, grads, sens, self.weights, self.weights_clone, *inputs))
             elif self.enable_adasum:
                 loss = F.depend(loss, self.adasum_process(loss, grads))
             else:
@@ -256,13 +256,14 @@ class BoostTrainOneStepCell(TrainOneStepCell):
             self.step += 1
         return loss

-    def gradient_accumulation_process(self, loss, grads, *inputs):
+    def gradient_accumulation_process(self, loss, grads, sens, *inputs):
         r"""
         Gradient accumulation algorithm process.
         Args:
             loss (Tensor): Tensor with shape :math:`()`.
             grads (tuple(Tensor)): Tuple of gradient tensors.
+            sens (Tensor): Tensor with shape :math:`()`.
         Outputs:
             - **loss** (Tensor) - Tensor with shape :math:`()`.
@@ -273,8 +274,8 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         if self.accumulation_step >= self.max_accumulation_step:
             if self.enable_dim_reduce:
-                loss = F.depend(loss, self.dim_reduce(loss, self.grad_accumulation, self.weights, self.weights_clone,
-                                                      *inputs))
+                loss = F.depend(loss, self.dim_reduce(loss, self.grad_accumulation, sens, self.weights,
+                                                      self.weights_clone, *inputs))
             elif self.enable_adasum:
                 loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation))
             else:
@@ -449,10 +450,11 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         # if there is no overflow, do optimize
         if not overflow:
             if self.use_grad_accumulation:
-                loss = self.gradient_accumulation_process(loss, grads, *inputs)
+                loss = self.gradient_accumulation_process(loss, grads, scaling_sens_filled, *inputs)
             else:
                 if self.enable_dim_reduce:
-                    loss = F.depend(loss, self.dim_reduce(loss, grads, self.weights, self.weights_clone, *inputs))
+                    loss = F.depend(loss, self.dim_reduce(loss, grads, scaling_sens_filled, self.weights,
+                                                          self.weights_clone, *inputs))
                 elif self.enable_adasum:
                     loss = F.depend(loss, self.adasum_process(loss, grads))
                 else:
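
A simplified, framework-free sketch of the control flow above after this change: the scale tensor (`sens` in `construct`, `scaling_sens_filled` in the loss-scale cell) now travels with the gradients instead of being baked into the optimizer. The callable names are placeholders, not the actual Cell methods:

```python
def train_step(loss, grads, sens, *, use_grad_accumulation, enable_dim_reduce,
               gradient_accumulation_process, dim_reduce, apply_optimizer):
    # 'sens' is the loss-scale tensor; it rides along until the gradients are consumed
    if use_grad_accumulation:
        return gradient_accumulation_process(loss, grads, sens)
    if enable_dim_reduce:
        # DimReduce un-scales the gradients itself before the line search
        return dim_reduce(loss, grads, sens)
    apply_optimizer(grads)
    return loss

# toy usage with stand-in callables
print(train_step(1.0, (2048.0,), 1024.0,
                 use_grad_accumulation=False, enable_dim_reduce=True,
                 gradient_accumulation_process=lambda l, g, s: l,
                 dim_reduce=lambda l, g, s: l,
                 apply_optimizer=lambda g: None))  # 1.0
```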

View File

@@ -27,6 +27,16 @@ from mindspore.common import dtype as mstype
 __all__ = ["DimReduce"]

+_scale_grad = C.MultitypeFuncGraph("_scale_grad")
+
+
+@_scale_grad.register("Tensor", "Tensor")
+def _scale_grad_process(scale, grad):
+    grad = F.cast(grad, mstype.float32)
+    grad = P.Div()(grad, scale)
+    return grad

 _save_weight = C.MultitypeFuncGraph("_save_weight")
@@ -156,7 +166,7 @@ class DimReduce(Cell):
     def _set_rho_list(self, rho):
         """set rho list info."""
-        self.max_search_time = 3
+        self.max_search_time = 2
         self.rho_list = []
         for i in range(self.max_search_time):
             self.rho_list.append(Tensor(np.power(rho, i), dtype=self.float_type))
@@ -184,8 +194,7 @@ class DimReduce(Cell):
         self.dk_pad_part = Tensor(np.zeros([pad_num, 1]), dtype=self.float_type)
         self.broadcast_list = []
-        pca_rank_num = math.ceil(self.n_components / local_dim)
-        for i in range(pca_rank_num):
+        for i in range(self.rank_size):
             broadcast = P.Broadcast(i)
             self.broadcast_list.append(broadcast)
@@ -200,13 +209,16 @@ class DimReduce(Cell):
         self.sk = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="sk")
         self.eye = Tensor(np.eye(self.n_components), dtype=self.float_type)
         self.grad_res_momentum = ParameterTuple(parameter_tuple).clone(prefix="grad_res_momentum", init="zeros")
         self.gk_last_back = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type),
                                       name="gk_last_back")
         self.bk_back = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk_back")
+        self.grad_proj_init = ParameterTuple(parameter_tuple).clone(prefix="grad_proj_init", init="zeros")
+        self.dn_init = ParameterTuple(parameter_tuple).clone(prefix="dn_init", init="zeros")

-    def construct(self, loss, old_grad, weight, weight_clone, *inputs):
+    def construct(self, loss, old_grad, loss_scale, weight, weight_clone, *inputs):
         weight = F.depend(weight, loss)
         old_grad = F.depend(old_grad, weight)
+        old_grad = self.hyper_map(F.partial(_scale_grad, loss_scale), old_grad)
         old_loss = self.allreduce(loss) / self.rank_size
         gk_local = self.hyper_map(_pca_projection, self.pca_list_local, old_grad)
@@ -227,15 +239,19 @@ class DimReduce(Cell):
         dn_local = self.hyper_map(F.partial(_pca_back_projection, dk_local), self.pca_list_local, old_grad)
         grad_proj_local = self.hyper_map(F.partial(_pca_back_projection, gk_local), self.pca_list_local, old_grad)

-        dn = dn_local
-        grad_proj = grad_proj_local
+        dn = self.dn_init
+        grad_proj = self.grad_proj_init
         for broadcast in self.broadcast_list:
             dn_part = broadcast(dn_local)
             dn = self.hyper_map(self.add, dn, dn_part)
             grad_proj_part = broadcast(grad_proj_local)
             grad_proj = self.hyper_map(self.add, grad_proj, grad_proj_part)

-        rho = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs)
+        rho, find = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs)
+        if not find:
+            _save_weight(self.gk_last, self.gk_last_back)
+            _save_weight(self.bk, self.bk_back)
         update_grad = self.hyper_map(F.partial(_update_grad_res_momentum, self.gamma, self.alpha),
                                      self.grad_res_momentum, old_grad, grad_proj)
         delta_weight = self.hyper_map(F.partial(_get_delta_weight, rho), dn, update_grad)
@@ -254,10 +270,7 @@ class DimReduce(Cell):
             if find:
                 res = self.rho_list[i]
                 break
-        if not find:
-            _save_weight(self.gk_last, self.gk_last_back)
-            _save_weight(self.bk, self.bk_back)
-        return res
+        return res, find

     def _find_rho(self, gk, dk, dn, old_loss, weight, weight_clone, rho, *inputs):
         """search rho."""