!29708 update dim_reduce: move scale_loss from optimizer to outer, add param filter
Merge pull request !29708 from jinjiali-kali/r1.6
This commit is contained in:
commit aecc3ffe58
@@ -276,7 +276,7 @@ class ParameterProcess:
         return group_params


-def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_number):
+def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_number, network):
     """
     get local pca mat path.
@@ -285,6 +285,7 @@ def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_n
         pca_mat_path (str): the path to load pca mat. Default: None.
         n_component (int): pca component.
         device_number (int): device number.
+        network (Cell): The network.
     """
     if pca_mat_path is not None and os.path.exists(pca_mat_path) and os.path.isfile(pca_mat_path) and \
             pca_mat_path.endswith(".npy"):
@@ -312,20 +313,25 @@ def _get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_n
     if pca_mat_exist:
         pca_mat = np.load(full_pca_mat_path)
     else:
-        data = _load_weights(weight_load_dir)
+        data = _load_weights(weight_load_dir, network)
         pca_mat = _compute_pca_mat(data, n_component)
-        np.save(full_pca_mat_path, pca_mat)
+        _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component)
     return local_pca_mat_path


-def _load_weights(weight_load_dir):
+def _load_weights(weight_load_dir, network):
     """
     load weights.

     Args:
         weight_load_dir (str): The weight(ckpt) file directory to be load.
+        network (Cell): The network.
     """
+    param_requires_grad_dict = {}
+    for param in network.trainable_params():
+        param_requires_grad_dict[param.name] = param.requires_grad
+
     param_mat_tuple = ()
     weight_file_list = os.listdir(weight_load_dir)
     for file in weight_file_list:
@@ -334,8 +340,9 @@ def _load_weights(weight_load_dir):
         file_path = os.path.join(weight_load_dir, file)
         param_dict = load_checkpoint(file_path)
         param_tuple = ()
-        for _, value in param_dict.items():
-            param_tuple += (value.asnumpy().reshape((1, -1)),)
+        for key, value in param_dict.items():
+            if param_requires_grad_dict[key]:
+                param_tuple += (value.asnumpy().reshape((1, -1)),)
         param = np.concatenate(param_tuple, axis=1)
         param_mat_tuple += (param,)
     param_mat = np.concatenate(param_mat_tuple, axis=0)
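Taken out of the diff, the new parameter filter in _load_weights amounts to: build a name-to-requires_grad map from network.trainable_params(), then skip every checkpoint entry that is not trainable before flattening it into the matrix handed to PCA. A minimal standalone sketch of that logic (the helper name load_trainable_weights and the ".ckpt" suffix check are illustrative, not part of the patch):

# Hedged sketch of the trainable-parameter filter; not the patched function itself.
import os
import numpy as np
from mindspore import load_checkpoint


def load_trainable_weights(weight_load_dir, network):
    """Flatten each checkpoint into one row, keeping only trainable parameters."""
    requires_grad = {p.name: p.requires_grad for p in network.trainable_params()}
    rows = []
    for file in os.listdir(weight_load_dir):
        if not file.endswith(".ckpt"):          # illustrative file filter
            continue
        param_dict = load_checkpoint(os.path.join(weight_load_dir, file))
        flat = [value.asnumpy().reshape((1, -1))
                for key, value in param_dict.items()
                if requires_grad.get(key, False)]
        rows.append(np.concatenate(flat, axis=1))
    return np.concatenate(rows, axis=0)

Compared with the old behaviour, frozen parameters (requires_grad=False) no longer inflate the flattened weight matrix, so the PCA basis is computed only over the dimensions that dim_reduce will actually update.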
@@ -245,8 +245,8 @@ class AutoBoost:
         """
         if self.boost_config["dim_reduce"]:
             self.local_pca_mat_path = _get_local_pca_mat_path(self.weight_load_dir, self.pca_mat_path,
-                                                              self.n_components, self.device_number)
-            optimizer = SGD(network.trainable_params(), learning_rate=1, loss_scale=optimizer.loss_scale)
+                                                              self.n_components, self.device_number, network)
+            optimizer = SGD(network.trainable_params(), learning_rate=1)
             setattr(optimizer, "dim_reduce", True)
             return network, optimizer
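Because the loss scale is now divided out inside DimReduce (see the new _scale_grad helper further down), the inner SGD created by AutoBoost no longer needs loss_scale=optimizer.loss_scale; it is only a pass-through that applies the update DimReduce has already computed. A hedged sketch of the resulting setup, with a hypothetical helper name:

# Illustrative helper mirroring the new AutoBoost branch for dim_reduce;
# build_dim_reduce_optimizer is not MindSpore API.
from mindspore.nn import SGD


def build_dim_reduce_optimizer(network):
    # learning_rate=1 and no loss_scale: the step size (rho) and gradient
    # un-scaling are both handled inside DimReduce.
    optimizer = SGD(network.trainable_params(), learning_rate=1)
    setattr(optimizer, "dim_reduce", True)  # flag read by the boost train cells
    return optimizer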
@@ -223,10 +223,10 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         grads = self.grad(self.network, self.weights)(*inputs, sens)
         grads = self.grad_reducer(grads)
         if self.use_grad_accumulation:
-            loss = self.gradient_accumulation_process(loss, grads, *inputs)
+            loss = self.gradient_accumulation_process(loss, grads, sens, *inputs)
         else:
             if self.enable_dim_reduce:
-                loss = F.depend(loss, self.dim_reduce(loss, grads, self.weights, self.weights_clone, *inputs))
+                loss = F.depend(loss, self.dim_reduce(loss, grads, sens, self.weights, self.weights_clone, *inputs))
             elif self.enable_adasum:
                 loss = F.depend(loss, self.adasum_process(loss, grads))
             else:
@@ -256,13 +256,14 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         self.step += 1
         return loss

-    def gradient_accumulation_process(self, loss, grads, *inputs):
+    def gradient_accumulation_process(self, loss, grads, sens, *inputs):
         r"""
         Gradient accumulation algorithm process.

         Args:
             loss (Tensor): Tensor with shape :math:`()`.
             grads (tuple(Tensor)): Tuple of gradient tensors.
+            sens (Tensor): Tensor with shape :math:`()`.

         Outputs:
             - **loss** (Tensor) - Tensor with shape :math:`()`.
@@ -273,8 +274,8 @@ class BoostTrainOneStepCell(TrainOneStepCell):

         if self.accumulation_step >= self.max_accumulation_step:
             if self.enable_dim_reduce:
-                loss = F.depend(loss, self.dim_reduce(loss, self.grad_accumulation, self.weights, self.weights_clone,
-                                                      *inputs))
+                loss = F.depend(loss, self.dim_reduce(loss, self.grad_accumulation, sens, self.weights,
+                                                      self.weights_clone, *inputs))
             elif self.enable_adasum:
                 loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation))
             else:
@@ -449,10 +450,11 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         # if there is no overflow, do optimize
         if not overflow:
             if self.use_grad_accumulation:
-                loss = self.gradient_accumulation_process(loss, grads, *inputs)
+                loss = self.gradient_accumulation_process(loss, grads, scaling_sens_filled, *inputs)
             else:
                 if self.enable_dim_reduce:
-                    loss = F.depend(loss, self.dim_reduce(loss, grads, self.weights, self.weights_clone, *inputs))
+                    loss = F.depend(loss, self.dim_reduce(loss, grads, scaling_sens_filled, self.weights,
+                                                          self.weights_clone, *inputs))
                 elif self.enable_adasum:
                     loss = F.depend(loss, self.adasum_process(loss, grads))
                 else:
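The reason sens (and scaling_sens_filled in the loss-scale cell) is now threaded through to gradient_accumulation_process and dim_reduce: the gradients produced by backprop still carry the loss-scale factor, and after this change DimReduce removes it itself rather than leaving it to the optimizer's loss_scale. The underlying arithmetic, as a tiny self-contained check:

# Plain NumPy illustration of the relationship the change relies on:
# backprop yields grad * sens, and _scale_grad divides by sens to recover grad.
import numpy as np

sens = np.float32(1024.0)                       # loss-scale value
true_grad = np.array([0.5, -1.25], dtype=np.float32)
scaled_grad = true_grad * sens                  # what the wrapper cell receives
recovered = scaled_grad / sens                  # what DimReduce now computes
assert np.allclose(recovered, true_grad)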
@@ -27,6 +27,16 @@ from mindspore.common import dtype as mstype
 __all__ = ["DimReduce"]


+_scale_grad = C.MultitypeFuncGraph("_scale_grad")
+
+
+@_scale_grad.register("Tensor", "Tensor")
+def _scale_grad_process(scale, grad):
+    grad = F.cast(grad, mstype.float32)
+    grad = P.Div()(grad, scale)
+    return grad
+
+
 _save_weight = C.MultitypeFuncGraph("_save_weight")
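For context, the new _scale_grad graph is applied in DimReduce.construct by mapping it over the gradient tuple with hyper_map, binding the loss-scale tensor through F.partial. A minimal sketch of that same pattern in isolation (assumes MindSpore is installed; the toy cell and values are illustrative, not from the patch):

# Standalone sketch of the _scale_grad + hyper_map pattern used by DimReduce.
import numpy as np
from mindspore import Tensor, nn
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C, functional as F, operations as P

_scale_grad = C.MultitypeFuncGraph("_scale_grad")


@_scale_grad.register("Tensor", "Tensor")
def _scale_grad_process(scale, grad):
    grad = F.cast(grad, mstype.float32)   # promote fp16 grads before dividing
    grad = P.Div()(grad, scale)           # undo the loss scale
    return grad


class UnscaleGrads(nn.Cell):
    """Toy cell that un-scales a tuple of gradients the way DimReduce does."""

    def __init__(self):
        super(UnscaleGrads, self).__init__()
        self.hyper_map = C.HyperMap()

    def construct(self, loss_scale, grads):
        return self.hyper_map(F.partial(_scale_grad, loss_scale), grads)


grads = (Tensor(np.ones((2, 3)) * 1024.0, mstype.float16),)
unscaled = UnscaleGrads()(Tensor(1024.0, mstype.float32), grads)  # tuple of fp32 grads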
@@ -156,7 +166,7 @@ class DimReduce(Cell):

     def _set_rho_list(self, rho):
         """set rho list info."""
-        self.max_search_time = 3
+        self.max_search_time = 2
         self.rho_list = []
         for i in range(self.max_search_time):
             self.rho_list.append(Tensor(np.power(rho, i), dtype=self.float_type))
@@ -184,8 +194,7 @@ class DimReduce(Cell):
         self.dk_pad_part = Tensor(np.zeros([pad_num, 1]), dtype=self.float_type)

         self.broadcast_list = []
-        pca_rank_num = math.ceil(self.n_components / local_dim)
-        for i in range(pca_rank_num):
+        for i in range(self.rank_size):
             broadcast = P.Broadcast(i)
             self.broadcast_list.append(broadcast)
@@ -200,13 +209,16 @@ class DimReduce(Cell):
         self.sk = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="sk")
         self.eye = Tensor(np.eye(self.n_components), dtype=self.float_type)
         self.grad_res_momentum = ParameterTuple(parameter_tuple).clone(prefix="grad_res_momentum", init="zeros")
+        self.gk_last_back = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type),
+                                      name="gk_last_back")
+        self.bk_back = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk_back")
+        self.grad_proj_init = ParameterTuple(parameter_tuple).clone(prefix="grad_proj_init", init="zeros")
+        self.dn_init = ParameterTuple(parameter_tuple).clone(prefix="dn_init", init="zeros")

-    def construct(self, loss, old_grad, weight, weight_clone, *inputs):
+    def construct(self, loss, old_grad, loss_scale, weight, weight_clone, *inputs):
         weight = F.depend(weight, loss)
         old_grad = F.depend(old_grad, weight)
+        old_grad = self.hyper_map(F.partial(_scale_grad, loss_scale), old_grad)
         old_loss = self.allreduce(loss) / self.rank_size

         gk_local = self.hyper_map(_pca_projection, self.pca_list_local, old_grad)
@@ -227,15 +239,19 @@ class DimReduce(Cell):

         dn_local = self.hyper_map(F.partial(_pca_back_projection, dk_local), self.pca_list_local, old_grad)
         grad_proj_local = self.hyper_map(F.partial(_pca_back_projection, gk_local), self.pca_list_local, old_grad)
-        dn = dn_local
-        grad_proj = grad_proj_local
+        dn = self.dn_init
+        grad_proj = self.grad_proj_init
         for broadcast in self.broadcast_list:
             dn_part = broadcast(dn_local)
             dn = self.hyper_map(self.add, dn, dn_part)
             grad_proj_part = broadcast(grad_proj_local)
             grad_proj = self.hyper_map(self.add, grad_proj, grad_proj_part)

-        rho = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs)
+        rho, find = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs)
+        if not find:
+            _save_weight(self.gk_last, self.gk_last_back)
+            _save_weight(self.bk, self.bk_back)

         update_grad = self.hyper_map(F.partial(_update_grad_res_momentum, self.gamma, self.alpha),
                                      self.grad_res_momentum, old_grad, grad_proj)
         delta_weight = self.hyper_map(F.partial(_get_delta_weight, rho), dn, update_grad)
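The switch from dn = dn_local / grad_proj = grad_proj_local to the zero-initialized self.dn_init / self.grad_proj_init buffers pairs with the broadcast loop now running over every rank: each rank, including the local one, contributes its share through a Broadcast, so the accumulator has to start at zero, otherwise the local contribution is added twice. Reduced to plain NumPy (ranks simulated in-process, values arbitrary):

# NumPy simulation of the accumulation pattern: summing every rank's broadcast
# part only gives the correct total when the accumulator starts at zero.
import numpy as np

rank_parts = [np.full(4, r + 1.0) for r in range(3)]   # stand-ins for each rank's dn_local
expected = sum(rank_parts)

local = rank_parts[0]                                   # view from rank 0
old_style = local.copy()                                # old code: start from dn_local
new_style = np.zeros_like(local)                        # new code: start from dn_init (zeros)
for part in rank_parts:                                 # each Broadcast delivers one rank's part
    old_style = old_style + part
    new_style = new_style + part

assert np.allclose(new_style, expected)
assert not np.allclose(old_style, expected)             # rank 0's share counted twice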
@@ -254,10 +270,7 @@ class DimReduce(Cell):
             if find:
                 res = self.rho_list[i]
                 break
-        if not find:
-            _save_weight(self.gk_last, self.gk_last_back)
-            _save_weight(self.bk, self.bk_back)
-        return res
+        return res, find

     def _find_rho(self, gk, dk, dn, old_loss, weight, weight_clone, rho, *inputs):
         """search rho."""