forked from mindspore-Ecosystem/mindspore
add dimension reduce training
parent 600c6421f1
commit b96eba92b4
@@ -26,6 +26,7 @@ from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell
from .grad_accumulation import GradientAccumulation
from .adasum import AdaSum
from .dim_reduce import DimReduce


__all__ = ['AutoBoost',

@@ -34,4 +35,4 @@ __all__ = ['AutoBoost',
           'LessBN',
           'GradientFreeze', 'FreezeOpt', 'freeze_cell',
           'GradientAccumulation',
           'AdaSum']
           'AdaSum', 'DimReduce']

@@ -22,7 +22,6 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common.tensor import Tensor


__all__ = ["AdaSum"]

@@ -57,8 +56,6 @@ def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisi
    """send result and receive result."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        local_part = F.depend(local_part, recv_part)
        eps = 1e-12
        value_0 = P.ReduceSum()(local_part * recv_part) + eps

@@ -128,10 +125,6 @@ def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, sen
        recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        if F.shape(delta_w) is None:
            delta_w = Tensor([delta_w])
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

@@ -13,15 +13,23 @@
# limitations under the License.
# ============================================================================
"""base process"""
import os
import time
import math
import copy
import numpy as np
from scipy import linalg as la
import mindspore.nn as nn
from mindspore.nn.optim import LARS
from mindspore import log as logger
from mindspore.common import Parameter
from mindspore.communication.management import get_group_size
from mindspore.parallel._utils import _get_global_rank
from mindspore.train.serialization import load_checkpoint
from .less_batch_normalization import CommonHeadLastFN


__all__ = ["OptimizerProcess", "ParameterProcess"]
__all__ = ["OptimizerProcess", "ParameterProcess", "get_local_pca_mat_path", "load_local_pca_mat"]


class OptimizerProcess:

@@ -265,3 +273,150 @@ class ParameterProcess:
        else:
            group_params.append({"params": params_value})
        return group_params


def get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_number):
    """
    Get the local pca mat path.

    Args:
        weight_load_dir (str): The directory of the weight (ckpt) files to be loaded.
        pca_mat_path (str): The path of the pca mat to load. Default: None.
        n_component (int): Number of PCA components.
        device_number (int): Device number.
    """
    if pca_mat_path is not None and os.path.exists(pca_mat_path) and os.path.isfile(pca_mat_path) and \
            pca_mat_path.endswith(".npy"):
        full_pca_mat_path = pca_mat_path
        pca_mat_exist = True

    else:
        if weight_load_dir is None or not os.path.exists(weight_load_dir) or not os.path.isdir(weight_load_dir):
            raise ValueError("The weight_load_dir: {} is None, does not exist or is not a directory.".format(weight_load_dir))

        full_pca_mat_path = os.path.join(weight_load_dir, "pca_mat_temp.npy")
        pca_mat_exist = False

    rank = _get_global_rank()
    local_pca_mat_path = full_pca_mat_path[:-4] + "_rank_" + str(rank) + ".npy"
    if os.path.exists(local_pca_mat_path):
        os.remove(local_pca_mat_path)
    if rank % device_number != 0:
        return local_pca_mat_path

    if pca_mat_exist:
        pca_mat = np.load(full_pca_mat_path)
    else:
        data = _load_weights(weight_load_dir)
        pca_mat = _compute_pca_mat(data, n_component)
        np.save(full_pca_mat_path, pca_mat)
    _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component)
    return local_pca_mat_path


def _load_weights(weight_load_dir):
    """
    Load weights.

    Args:
        weight_load_dir (str): The directory of the weight (ckpt) files to be loaded.
    """
    param_mat = None
    weight_file_list = os.listdir(weight_load_dir)
    for file in weight_file_list:
        if not file.endswith('.ckpt'):
            continue
        file_path = os.path.join(weight_load_dir, file)
        param_dict = load_checkpoint(file_path)
        param = None
        for _, value in param_dict.items():
            if param is None:
                param = value.asnumpy().reshape((1, -1))
            else:
                param = np.hstack((param, value.asnumpy().reshape((1, -1))))
        if param_mat is None:
            param_mat = param
        else:
            param_mat = np.vstack((param_mat, param))
    return param_mat


def _compute_pca_mat(data, n_component):
    """
    Compute the pca mat.

    Args:
        data: array-like of shape (n_samples, n_features).
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.
        n_component (int): Number of PCA components.
    """
    if data.shape[0] < n_component:
        raise ValueError("The number of samples: {} is less than n_component: {}.".format(data.shape[0], n_component))
    mean = np.mean(data, axis=0)
    data -= mean
    u, _, v = la.svd(data, full_matrices=False)
    _, v = _svd_flip(u, v)
    components = v[:n_component]
    return components


def _svd_flip(u, v):
    """
    Svd flip.

    Args:
        u (ndarray): The output of `linalg.svd`.
        v (ndarray): The output of `linalg.svd`.
    """
    max_abs_cols = np.argmax(np.abs(u), axis=0)
    signs = np.sign(u[max_abs_cols, range(u.shape[1])])
    u *= signs
    v *= signs[:, np.newaxis]
    return u, v


def _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component):
    """
    Save the local pca mat for each rank.

    Args:
        pca_mat (numpy.ndarray): The pca mat to be saved.
        full_pca_mat_path (str): The path of the full pca mat.
        n_component (int): Number of PCA components.
    """
    rank_size = get_group_size()
    local_dim = math.ceil(n_component / rank_size)
    for rank_id in range(rank_size):
        start_index = rank_id * local_dim
        end_index = (rank_id + 1) * local_dim
        pca_start_index = min(n_component, start_index)
        pca_end_index = min(n_component, end_index)
        p_local = np.zeros([local_dim + 1, pca_mat.shape[1]])
        if pca_start_index != pca_end_index:
            p_local[0: pca_end_index - pca_start_index, :] = pca_mat[pca_start_index: pca_end_index, :]
        local_pca_mat_path = full_pca_mat_path[:-4] + "_rank_" + str(rank_id) + ".npy"
        np.save(local_pca_mat_path, p_local)


def load_local_pca_mat(local_pca_mat_path, n_component):
    """
    Load the local pca mat.

    Args:
        local_pca_mat_path (str): The local pca mat file path.
        n_component (int): Number of PCA components.
    """
    rank_size = get_group_size()
    local_dim = math.ceil(n_component / rank_size)
    while True:
        if os.path.exists(local_pca_mat_path):
            break
        time.sleep(5)
    while True:
        pca_mat = np.load(local_pca_mat_path)
        if pca_mat.shape[0] == local_dim + 1:
            break
        time.sleep(5)
    pca_mat = pca_mat[:-1, :]
    return pca_mat

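Note: the helpers added above build a single PCA projection matrix from a directory of checkpoints (each checkpoint flattened into one row), then shard it row-wise across ranks, padding every shard to the same height so `load_local_pca_mat` can detect a complete file before stripping the padding row. A minimal NumPy-only sketch of that flow, with random stand-in data instead of real .ckpt files and illustrative function names (not MindSpore API), might look like:

import math
import numpy as np

def compute_pca_components(samples, n_component):
    # samples: (n_samples, n_features), one flattened weight snapshot per row,
    # mirroring _load_weights + _compute_pca_mat.
    samples = samples - samples.mean(axis=0)
    u, _, v = np.linalg.svd(samples, full_matrices=False)
    # deterministic sign of each component, as in _svd_flip
    signs = np.sign(u[np.argmax(np.abs(u), axis=0), range(u.shape[1])])
    return (v * signs[:, np.newaxis])[:n_component]          # shape (k, n_features)

def shard_components(pca_mat, n_component, rank_size):
    # each rank gets ceil(k / rank_size) rows plus one padding row, as in _save_local_pca_mat
    local_dim = math.ceil(n_component / rank_size)
    shards = []
    for rank_id in range(rank_size):
        p_local = np.zeros([local_dim + 1, pca_mat.shape[1]])
        rows = pca_mat[rank_id * local_dim: (rank_id + 1) * local_dim, :]
        p_local[:rows.shape[0], :] = rows
        shards.append(p_local)
    return shards

snapshots = np.random.randn(16, 1024)          # stand-in for 16 flattened checkpoints
components = compute_pca_components(snapshots, n_component=8)
local_parts = shard_components(components, n_component=8, rank_size=4)
print(components.shape, local_parts[0].shape)  # (8, 1024) and (3, 1024)
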
@@ -14,9 +14,11 @@
# ============================================================================
"""boost"""
import threading
from mindspore.nn.optim import SGD
from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze
from .base import OptimizerProcess, ParameterProcess
from .base import get_local_pca_mat_path


__all__ = ["AutoBoost"]

@@ -27,17 +29,20 @@ _boost_config_level = {
        "less_bn": False,
        "grad_freeze": False,
        "adasum": False,
        "grad_accumulation": False},
        "grad_accumulation": False,
        "dim_reduce": False},
    "O1": {
        "less_bn": True,
        "grad_freeze": True,
        "adasum": False,
        "grad_accumulation": False},
        "grad_accumulation": False,
        "dim_reduce": False},
    "O2": {
        "less_bn": True,
        "grad_freeze": True,
        "adasum": True,
        "grad_accumulation": False}}
        "grad_accumulation": False,
        "dim_reduce": False}}


class AutoBoost:

@@ -57,10 +62,12 @@ class AutoBoost:
            "less_bn": false,
            "grad_freeze": false,
            "adasum": false,
            "grad_accumulation": false
            "grad_accumulation": false,
            "dim_reduce": false
        },
        "common": {
            "gradient_split_groups": [50, 100]
            "gradient_split_groups": [50, 100],
            "device_number": 8
        },
        "less_bn": {
            "fn_flag": true,

@@ -73,10 +80,20 @@ class AutoBoost:
            "total_steps": 65536
        },
        "adasum": {
            "device_number": 8

        },
        "grad_accumulation": {
            "grad_accumulation_step": 1
        },
        "dim_reduce": {
            "ls_weight_decay": 0.0001,
            "rho": 0.55,
            "gamma": 0.9,
            "alpha": 0.001,
            "sigma": 0.4,
            "n_components": 32,
            "pca_mat_path": None,
            "weight_load_dir": None
        }
    }

@@ -120,6 +137,15 @@ class AutoBoost:
        self.gradient_groups = None
        self.device_number = 8
        self.grad_accumulation_step = 1
        self.ls_weight_decay = 0.0001
        self.rho = 0.55
        self.gamma = 0.9
        self.alpha = 0.001
        self.sigma = 0.4
        self.n_components = 32
        self.pca_mat_path = None
        self.weight_load_dir = None
        self.local_pca_mat_path = None
        self.boost_config = self._get_configuration(level, self.boost_config_dict)
        self._param_processer = ParameterProcess()

@@ -141,6 +167,13 @@ class AutoBoost:
            network (Cell): The training network.
            optimizer (Cell): Optimizer for updating the weights.
        """
        if self.boost_config["dim_reduce"]:
            self.local_pca_mat_path = get_local_pca_mat_path(self.weight_load_dir, self.pca_mat_path,
                                                             self.n_components, self.device_number)
            optimizer = SGD(network.trainable_params(), learning_rate=1, loss_scale=optimizer.loss_scale)
            setattr(optimizer, "dim_reduce", True)
            return network, optimizer

        if self.boost_config["less_bn"]:
            network = LessBN(network, fn_flag=self._fn_flag)
            optimizer_process = OptimizerProcess(optimizer)

@@ -168,6 +201,8 @@ class AutoBoost:
        Args:
            network (Cell): The inference network.
        """
        if self.boost_config["dim_reduce"]:
            return network
        if self.boost_config["less_bn"]:
            network = LessBN(network)

@@ -204,6 +239,30 @@ class AutoBoost:
        gradient_groups = list(gradient_groups)
        self.gradient_groups = gradient_groups

    def set_ls_weight_decay(self, ls_weight_decay):
        self.ls_weight_decay = ls_weight_decay

    def set_rho(self, rho):
        self.rho = rho

    def set_gamma(self, gamma):
        self.gamma = gamma

    def set_alpha(self, alpha):
        self.alpha = alpha

    def set_sigma(self, sigma):
        self.sigma = sigma

    def set_n_components(self, n_components):
        self.n_components = n_components

    def set_pca_mat_path(self, pca_mat_path):
        self.pca_mat_path = pca_mat_path

    def set_weight_load_dir(self, weight_load_dir):
        self.weight_load_dir = weight_load_dir

    def _get_configuration(self, level, boost_config_dict):
        """Get configuration."""
        level_config = _boost_config_level[level]

@@ -229,7 +288,6 @@ class AutoBoost:
                self._boost_config_func_map[key_s](self, boost_each_mode_config[key_s])
        return level_config


    _boost_config_func_map = {
        "fn_flag": set_fn_flag,
        "gc_flag": set_gc_flag,

@@ -239,5 +297,13 @@ class AutoBoost:
        "total_steps": set_total_steps,
        "device_number": set_device_number,
        "gradient_split_groups": set_gradient_split_groups,
        "grad_accumulation_step": set_grad_accumulation_step
        "grad_accumulation_step": set_grad_accumulation_step,
        "ls_weight_decay": set_ls_weight_decay,
        "rho": set_rho,
        "gamma": set_gamma,
        "alpha": set_alpha,
        "sigma": set_sigma,
        "n_components": set_n_components,
        "pca_mat_path": set_pca_mat_path,
        "weight_load_dir": set_weight_load_dir,
    }

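Note: every key in the per-feature sections of the docstring example above is routed through `_boost_config_func_map` to a setter (for instance "rho" calls set_rho). A sketch of a configuration dict that enables dimension reduce training could therefore look like the following; the name of the section holding the boolean feature flags is cut off in this diff, so "boost" here is an assumption, and the paths are placeholders:

boost_config = {
    "boost": {                            # assumed section name; the flag block above is truncated
        "less_bn": False,
        "grad_freeze": False,
        "adasum": False,
        "grad_accumulation": False,
        "dim_reduce": True,
    },
    "common": {
        "device_number": 8,
    },
    "dim_reduce": {
        "ls_weight_decay": 0.0001,
        "rho": 0.55,
        "gamma": 0.9,
        "alpha": 0.001,
        "sigma": 0.4,
        "n_components": 32,
        "pca_mat_path": None,             # or the path of a precomputed *.npy PCA matrix
        "weight_load_dir": "./ckpt_dir",  # placeholder: directory of .ckpt files used to fit the PCA
    },
}

When "dim_reduce" is enabled, the training-network hook shown above swaps the user optimizer for a plain SGD with learning_rate=1 and tags it with the dim_reduce attribute so that the boost train cell can detect it.
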
@@ -19,7 +19,7 @@ from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_gradients_mean
from mindspore.communication.management import get_group_size, create_group
from mindspore.nn.cell import Cell
from mindspore.common import Tensor
from mindspore.common import Tensor, RowTensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import functional as F

@@ -29,7 +29,9 @@ from mindspore.common import dtype as mstype
from .boost import AutoBoost
from .grad_freeze import FreezeOpt, freeze_cell
from .adasum import AdaSum
from .dim_reduce import DimReduce
from .grad_accumulation import gradient_accumulation_op, gradient_clear_op
from .base import load_local_pca_mat


__all__ = ["BoostTrainOneStepCell", "BoostTrainOneStepWithLossScaleCell"]

@@ -156,6 +158,23 @@ class BoostTrainOneStepCell(TrainOneStepCell):
        if self.use_grad_accumulation:
            self.grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')

        self.enable_dim_reduce = self.check_dim_reduce_enable()
        if self.enable_dim_reduce:
            local_pca_mat_path = auto_boost.local_pca_mat_path
            rho = auto_boost.rho
            ls_weight_decay = auto_boost.ls_weight_decay
            gamma = auto_boost.gamma
            alpha = auto_boost.alpha
            sigma = auto_boost.sigma
            _rank = _get_global_rank()
            _rank_size = get_group_size()
            _device_number = auto_boost.device_number
            n_components = auto_boost.n_components
            pca_mat = load_local_pca_mat(local_pca_mat_path, n_components)
            self.weights_clone = ParameterTuple(self.weights).clone(prefix="weights_clone", init="same")
            self.dim_reduce = DimReduce(self.network, self.optimizer, self.weights, pca_mat, n_components, rho,
                                        ls_weight_decay, gamma, alpha, sigma, _rank, _rank_size)

        self.freeze_nets = None
        self.step = Parameter(Tensor(0, dtype=mstype.int32))
        if self.freeze:

@@ -169,7 +188,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
                                           self.grad, self.use_grad_accumulation, self.mean, self.degree,
                                           self.max_accumulation_step)

        self.enable_adasum = self.check_adasum_enable(optimizer, self.reducer_flag)
        self.enable_adasum = self.check_adasum_enable()
        self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32))
        if self.enable_adasum:
            _rank = _get_global_rank()

@@ -204,9 +223,11 @@ class BoostTrainOneStepCell(TrainOneStepCell):
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.gradient_accumulation_process(loss, grads)
            loss = self.gradient_accumulation_process(loss, grads, *inputs)
        else:
            if self.enable_adasum:
            if self.enable_dim_reduce:
                loss = F.depend(loss, self.dim_reduce(loss, grads, self.weights, self.weights_clone, *inputs))
            elif self.enable_adasum:
                loss = F.depend(loss, self.adasum_process(loss, grads))
            else:
                loss = F.depend(loss, self.optimizer(grads))

@@ -235,7 +256,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
        self.step += 1
        return loss

    def gradient_accumulation_process(self, loss, grads):
    def gradient_accumulation_process(self, loss, grads, *inputs):
        r"""
        Gradient accumulation algorithm process.

@@ -251,7 +272,10 @@ class BoostTrainOneStepCell(TrainOneStepCell):
        self.accumulation_step += 1

        if self.accumulation_step >= self.max_accumulation_step:
            if self.enable_adasum:
            if self.enable_dim_reduce:
                loss = F.depend(loss, self.dim_reduce(loss, self.grad_accumulation, self.weights, self.weights_clone,
                                                      *inputs))
            elif self.enable_adasum:
                loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation))
            else:
                loss = F.depend(loss, self.optimizer(self.grad_accumulation))

@@ -290,15 +314,11 @@ class BoostTrainOneStepCell(TrainOneStepCell):
            self.hyper_map(F.partial(_save_weight), weight_tuple, update_weights)
        return loss

    def check_adasum_enable(self, optimizer, reducer_flag):
    def check_adasum_enable(self):
        r"""
        Check adasum enable.

        Args:
            optimizer (Union[Cell]): Optimizer for updating the weights.
            reducer_flag (bool): Reducer flag.
        """
        if not getattr(optimizer, "adasum", None) or not reducer_flag:
        if not getattr(self.optimizer, "adasum", None) or not self.reducer_flag:
            return False
        _rank_size = get_group_size()
        _device_number = 8

@@ -306,6 +326,14 @@ class BoostTrainOneStepCell(TrainOneStepCell):
        is_enable = bool(group_number > 1 and group_number & (group_number - 1) == 0)
        return is_enable

    def check_dim_reduce_enable(self):
        r"""
        Check dim_reduce enable.
        """
        if not getattr(self.optimizer, "dim_reduce", None):
            return False
        return True


class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
    r"""

@@ -421,10 +449,12 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
        # if there is no overflow, do optimize
        if not overflow:
            if self.use_grad_accumulation:
                loss = self.gradient_accumulation_process(loss, grads)
                loss = self.gradient_accumulation_process(loss, grads, *inputs)
            else:
                if self.enable_adasum:
                    loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation))
                if self.enable_dim_reduce:
                    loss = F.depend(loss, self.dim_reduce(loss, grads, self.weights, self.weights_clone, *inputs))
                elif self.enable_adasum:
                    loss = F.depend(loss, self.adasum_process(loss, grads))
                else:
                    loss = F.depend(loss, self.optimizer(grads))
        return loss, cond, scaling_sens

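Note: when enable_dim_reduce is set, the train cells above hand the full gradient tuple to DimReduce (added in the new file below), which works in a k-dimensional PCA subspace. Stripped of the Cell/HyperMap machinery, the projection and back-projection it performs per parameter reduce to two matrix products; a NumPy sketch with illustrative shapes:

import numpy as np

n, k = 1024, 32
pca_mat = np.linalg.qr(np.random.randn(n, k))[0].T   # (k, n): orthonormal rows as a stand-in PCA matrix
grad = np.random.randn(n, 1)                         # one flattened gradient

gk = pca_mat @ grad          # _pca_projection_process: project into the k-dimensional subspace
grad_proj = pca_mat.T @ gk   # _pca_back_projection_process: map the reduced gradient back

# the part of the gradient the subspace misses feeds the residual-momentum term
# (_update_grad_res_momentum_process uses grad - grad_proj)
residual = grad - grad_proj
print(gk.shape, grad_proj.shape, np.linalg.norm(residual) <= np.linalg.norm(grad))
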
@@ -0,0 +1,268 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dim_reduce"""
import math
import numpy as np
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype


__all__ = ["DimReduce"]


_save_weight = C.MultitypeFuncGraph("_save_weight")


@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(parameter, new_parameter):
    return P.Assign()(parameter, new_parameter)


_pca_projection = C.MultitypeFuncGraph("_pca_projection")


@_pca_projection.register("Tensor", "Tensor")
def _pca_projection_process(pca_mat, grad):
    grad_k = P.MatMul()(pca_mat, F.reshape(grad, (-1, 1)))
    return grad_k


_pca_back_projection = C.MultitypeFuncGraph("_pca_back_projection")


@_pca_back_projection.register("Tensor", "Tensor", "Tensor")
def _pca_back_projection_process(grad_k, pca_mat, grad):
    grad_proj = P.MatMul()(F.transpose(pca_mat, (1, 0)), grad_k)
    grad_proj_reshape = F.reshape(grad_proj, F.shape(grad))
    return grad_proj_reshape


_update_grad_res_momentum = C.MultitypeFuncGraph("_update_grad_res_momentum")


@_update_grad_res_momentum.register("Float32", "Float32", "Tensor", "Tensor", "Tensor")
def _update_grad_res_momentum_process(gamma, alpha, grad_res_momentum, grad, grad_proj):
    grad_res_momentum_new = gamma * grad_res_momentum + grad - grad_proj
    P.Assign()(grad_res_momentum, grad_res_momentum_new)
    res = alpha * grad_res_momentum_new
    return res


_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight")


@_get_delta_weight.register("Tensor", "Tensor", "Tensor")
def _get_delta_weight_process(rho, dn, grad_res_momentum):
    delta_weight = grad_res_momentum - rho * dn
    return delta_weight


class DimReduce(Cell):
    r"""
    Dimension reduce training is a novel algorithm for accelerating the convergence of deep learning models.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        weight (Tuple(Parameter)): Tuple of parameters.
        pca_mat_local (numpy.ndarray): Local PCA matrix of shape k*n, where k is this rank's share of
            n_components and n is the size of the flattened weights.
        n_components (int): PCA.components_.
        rho (float): Base factor of the line search step applied to the descent direction.
        ls_weight_decay (float): Weight decay coefficient applied to the L2 loss term.
        gamma (float): Momentum coefficient for the gradient residual.
        alpha (float): Scale applied to the gradient residual momentum.
        sigma (float): Coefficient of the sufficient-decrease condition used in the line search on the loss.
        rank (int): Rank number.
        rank_size (int): Rank size.

    Inputs:
        - **loss** (Tensor) - Tensor with shape :math:`()`.
        - **old_grad** (Tuple(Tensor)) - Tuple of gradient tensors.
        - **weight** (Tuple(Tensor)) - Tuple of parameters.
        - **weight_clone** (Tuple(Tensor)) - Clone of weight.
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        - **loss** (Tensor) - Tensor with shape :math:`()`.
    """
    def __init__(self, network, optimizer, weight, pca_mat_local, n_components, rho, ls_weight_decay, gamma, alpha,
                 sigma, rank, rank_size):
        super(DimReduce, self).__init__()
        self.network = network
        self.optimizer = optimizer
        self.rank = rank
        self.rank_size = rank_size
        self.ls_weight_decay = ls_weight_decay
        self.gamma = gamma
        self.alpha = alpha
        self.sigma = sigma

        self.float_type = mstype.float32
        self._set_rho_list(rho)
        self._set_local_pca_mat(pca_mat_local, n_components, weight)
        self._set_init_parameter(weight)

        self.hyper_map = C.HyperMap()
        self.allreduce = P.AllReduce()
        self.allgather = P.AllGather()
        self.concat = P.Concat()
        self.matmul = P.MatMul()
        self.mul = P.Mul()
        self.l2loss = P.L2Loss()
        self.add = P.Add()

    def _set_rho_list(self, rho):
        """set rho list info."""
        self.max_search_time = 3
        self.rho_list = []
        for i in range(self.max_search_time):
            self.rho_list.append(Tensor(np.power(rho, i), dtype=self.float_type))

    def _set_local_pca_mat(self, pca_mat_local, n_components, parameter_tuple):
        """set pca info."""
        self.n_components = n_components
        local_dim = math.ceil(self.n_components / self.rank_size)

        self.start_index = self.rank * local_dim
        self.end_index = (self.rank + 1) * local_dim

        start = 0
        self.pca_list_local = ()
        for param in parameter_tuple:
            size = np.shape(param.asnumpy().reshape((-1, 1)))[0]
            self.pca_list_local += (Tensor(pca_mat_local[:, start:start + size], dtype=self.float_type),)
            start += size

        self.dk_pad_flag = False
        pad_num = self.rank_size * local_dim - self.n_components
        if pad_num:
            self.dk_pad_flag = True
            self.dk_pad_part = Tensor(np.zeros([pad_num, 1]), dtype=self.float_type)

        self.broadcast_list = []
        pca_rank_num = math.ceil(self.n_components / local_dim)
        for i in range(pca_rank_num):
            broadcast = P.Broadcast(i)
            self.broadcast_list.append(broadcast)

    def _set_init_parameter(self, parameter_tuple):
        """init parameters."""
        self.true_flag = Tensor(True)
        self.false_flag = Tensor(False)
        self.epsilon = np.power(10.0, -20)
        self.gk_last = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="gk_last")
        self.gk_last_init = Parameter(Tensor(False), name="gk_last_init")
        self.bk = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk")
        self.sk = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="sk")
        self.eye = Tensor(np.eye(self.n_components), dtype=self.float_type)
        self.grad_res_momentum = ParameterTuple(parameter_tuple).clone(prefix="grad_res_momentum", init="zeros")

    def construct(self, loss, old_grad, weight, weight_clone, *inputs):
        weight = F.depend(weight, loss)
        l2_loss = self._get_l2_loss(weight)
        old_loss = self.allreduce(loss) / self.rank_size + l2_loss

        gk_local = self.hyper_map(_pca_projection, self.pca_list_local, old_grad)
        gk_local = F.addn(gk_local)
        gk_pad = self.allgather(gk_local)
        gk_pad = F.reshape(gk_pad, (-1, 1))
        gk = gk_pad[0:self.n_components, :]

        dk = self._apply_quasi_newton_update(gk)
        if self.dk_pad_flag:
            dk_pad = self.concat((dk, self.dk_pad_part))
        else:
            dk_pad = dk
        dk_local = dk_pad[self.start_index: self.end_index, :]

        dn_local = self.hyper_map(F.partial(_pca_back_projection, dk_local), self.pca_list_local, old_grad)
        grad_proj_local = self.hyper_map(F.partial(_pca_back_projection, gk_local), self.pca_list_local, old_grad)
        dn = dn_local
        grad_proj = grad_proj_local
        for broadcast in self.broadcast_list:
            dn_part = broadcast(dn_local)
            dn = self.hyper_map(self.add, dn, dn_part)
            grad_proj_part = broadcast(grad_proj_local)
            grad_proj = self.hyper_map(self.add, grad_proj, grad_proj_part)

        rho = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs)
        update_grad = self.hyper_map(F.partial(_update_grad_res_momentum, self.gamma, self.alpha),
                                     self.grad_res_momentum, old_grad, grad_proj)
        delta_weight = self.hyper_map(F.partial(_get_delta_weight, rho), dn, update_grad)
        update = self.optimizer(delta_weight)
        weight = F.depend(weight, update)
        clone = self.hyper_map(_save_weight, weight_clone, weight)
        loss = F.depend(loss, clone)
        return loss

    def _line_search(self, gk, dk, dn, old_loss, weight, weight_clone, *inputs):
        """line search rho."""
        res = self.rho_list[self.max_search_time - 1]
        for i in range(self.max_search_time):
            find = self._find_rho(gk, dk, dn, old_loss, weight, weight_clone, self.rho_list[i], *inputs)
            if find:
                res = self.rho_list[i]
                break
        return res

    def _find_rho(self, gk, dk, dn, old_loss, weight, weight_clone, rho, *inputs):
        """search rho."""
        res = self.false_flag
        sn = self.hyper_map(F.partial(self.mul, -1 * rho), dn)
        sn = F.depend(sn, old_loss)
        update = self.optimizer(sn)
        new_loss = self.network(*inputs)
        new_loss = self.allreduce(new_loss)
        weight = F.depend(weight, update)
        new_l2_loss = self._get_l2_loss(weight)
        new_loss = new_loss / self.rank_size + new_l2_loss
        old_loss_delta = old_loss + self.sigma * rho * F.squeeze(self.matmul(F.transpose(gk, (1, 0)), dk))
        if old_loss_delta > new_loss:
            _save_weight(self.sk, rho * dk)
            res = self.true_flag
        weight_clone = F.depend(weight_clone, old_loss_delta)
        restore = self.hyper_map(_save_weight, weight, weight_clone)
        res = F.depend(res, restore)
        return res

    def _apply_quasi_newton_update(self, gk):
        """apply quasi_newton update."""
        if self.gk_last_init:
            yk = gk - self.gk_last
            g = self.matmul(F.transpose(yk, (1, 0)), self.sk)
            g = F.squeeze(g)
            if g > self.epsilon:
                pk = 1. / g
                t1 = self.eye - self.matmul(pk * yk, F.transpose(self.sk, (1, 0)))
                new_bk = self.matmul(self.matmul(F.transpose(t1, (1, 0)), self.bk), t1) + \
                         self.matmul(pk * self.sk, F.transpose(self.sk, (1, 0)))
                _save_weight(self.bk, new_bk)
        else:
            _save_weight(self.gk_last_init, self.true_flag)
        _save_weight(self.gk_last, gk)
        dk = -1 * self.matmul(self.bk, gk)
        return dk

    def _get_l2_loss(self, weight):
        """get l2 loss."""
        l2_loss = self.hyper_map(self.l2loss, weight)
        l2_loss = F.addn(l2_loss)
        l2_loss *= self.ls_weight_decay
        return l2_loss

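Note: in the reduced space, _apply_quasi_newton_update maintains an approximate inverse Hessian bk with a BFGS-style rank-two update, and _line_search / _find_rho accept the first step factor rho**i that passes a sigma-scaled sufficient-decrease test, falling back to the smallest candidate. The following NumPy sketch replays that update rule on a toy quadratic; the toy problem and variable names are assumptions for illustration, not MindSpore API:

import numpy as np

k = 8
rng = np.random.default_rng(0)
a = rng.standard_normal((k, k))
a = a @ a.T + np.eye(k)
a = a / np.linalg.eigvalsh(a).max()           # SPD toy Hessian, largest eigenvalue 1
x = rng.standard_normal((k, 1))

def loss(z):                                  # toy objective: 0.5 * z^T A z
    return 0.5 * (z.T @ a @ z).item()

bk = np.eye(k)                                # approximate inverse Hessian, like self.bk
gk_last = np.zeros((k, 1))
sk = np.zeros((k, 1))
gk_last_init = False
epsilon = 1e-20
sigma, rho_base, max_search_time = 0.4, 0.55, 3

for _ in range(20):
    gk = a @ x                                # gradient in the k-dimensional subspace
    if gk_last_init:
        yk = gk - gk_last
        g = (yk.T @ sk).item()
        if g > epsilon:                       # curvature check, as in _apply_quasi_newton_update
            pk = 1.0 / g
            t1 = np.eye(k) - pk * (yk @ sk.T)
            bk = t1.T @ bk @ t1 + pk * (sk @ sk.T)
    else:
        gk_last_init = True
    gk_last = gk
    dk = -bk @ gk                             # descent direction
    rho = rho_base ** (max_search_time - 1)   # fallback: smallest candidate, as in _line_search
    for i in range(max_search_time):
        cand = rho_base ** i
        # sufficient-decrease test, mirroring old_loss_delta > new_loss in _find_rho
        if loss(x + cand * dk) < loss(x) + sigma * cand * (gk.T @ dk).item():
            rho = cand
            break
    sk = rho * dk                             # accepted step, stored like self.sk
    x = x + sk

print(loss(x), np.linalg.norm(a @ x))         # both should have shrunk toward zero
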
@@ -60,7 +60,7 @@ class GradientAccumulation(Cell):
        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")

    def construct(self, loss, grads):
        loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
        loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self._max_accumulation_step),
                                             self._grad_accumulation, grads))
        self._accumulation_step += 1

@@ -69,6 +69,6 @@ class GradientAccumulation(Cell):
            self._accumulation_step = 0

        if self._accumulation_step == 0:
            loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))
            loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))

        return loss

@@ -277,7 +277,7 @@ def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulat
        reducer_flag (bool): Reducer flag.
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Tensor): Tensor with shape :math:`()`
        sens (numbers.Number): The scaling number.
        grad (tuple(Tensor)): Tuple of gradient tensors.
        use_grad_accumulation (bool): Use gradient accumulation flag.
        mean (bool): Gradients mean flag. default: None.

@@ -185,6 +185,7 @@ class SGD(Optimizer):
        params = self.parameters
        accum = self.accum
        stat = self.stat
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""Utils of auto parallel"""

import numpy as np
from mindspore import context, log as logger
from mindspore.context import ParallelMode

@@ -110,6 +110,7 @@ def _check_level(level, boost_level):

    return level, enable_boost


def _add_loss_network(network, loss_fn, cast_model_type):
    """Add loss network."""

@@ -193,8 +193,8 @@ class Model:
        >>> model.train(2, dataset)
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", boost_level="O0", **kwargs):
    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
                 amp_level="O0", boost_level="O0", **kwargs):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer