add dimesion reduce training

This commit is contained in:
jinjiali 2021-12-03 15:37:52 +08:00
parent 600c6421f1
commit b96eba92b4
12 changed files with 552 additions and 38 deletions

View File

@ -26,6 +26,7 @@ from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell
from .grad_accumulation import GradientAccumulation from .grad_accumulation import GradientAccumulation
from .adasum import AdaSum from .adasum import AdaSum
from .dim_reduce import DimReduce
__all__ = ['AutoBoost', __all__ = ['AutoBoost',
@ -34,4 +35,4 @@ __all__ = ['AutoBoost',
'LessBN', 'LessBN',
'GradientFreeze', 'FreezeOpt', 'freeze_cell', 'GradientFreeze', 'FreezeOpt', 'freeze_cell',
'GradientAccumulation', 'GradientAccumulation',
'AdaSum'] 'AdaSum', 'DimReduce']

View File

@ -22,7 +22,6 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common.tensor import Tensor
__all__ = ["AdaSum"] __all__ = ["AdaSum"]
@ -57,8 +56,6 @@ def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisi
"""send result and receive result.""" """send result and receive result."""
if parameter_divisibility: if parameter_divisibility:
recv_part = P.Squeeze()(recv_part) recv_part = P.Squeeze()(recv_part)
if F.shape(recv_part) is None:
recv_part = Tensor([recv_part])
local_part = F.depend(local_part, recv_part) local_part = F.depend(local_part, recv_part)
eps = 1e-12 eps = 1e-12
value_0 = P.ReduceSum()(local_part * recv_part) + eps value_0 = P.ReduceSum()(local_part * recv_part) + eps
@ -128,10 +125,6 @@ def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, sen
recv_part = _receive_before_send(delta_w, send, recv) recv_part = _receive_before_send(delta_w, send, recv)
recv_part = P.Squeeze()(recv_part) recv_part = P.Squeeze()(recv_part)
if F.shape(recv_part) is None:
recv_part = Tensor([recv_part])
if F.shape(delta_w) is None:
delta_w = Tensor([delta_w])
recv_part = P.Reshape()(recv_part, (-1,)) recv_part = P.Reshape()(recv_part, (-1,))
delta_w = P.Reshape()(delta_w, (-1,)) delta_w = P.Reshape()(delta_w, (-1,))

View File

@ -13,15 +13,23 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""base process""" """base process"""
import os
import time
import math
import copy import copy
import numpy as np
from scipy import linalg as la
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.nn.optim import LARS from mindspore.nn.optim import LARS
from mindspore import log as logger from mindspore import log as logger
from mindspore.common import Parameter from mindspore.common import Parameter
from mindspore.communication.management import get_group_size
from mindspore.parallel._utils import _get_global_rank
from mindspore.train.serialization import load_checkpoint
from .less_batch_normalization import CommonHeadLastFN from .less_batch_normalization import CommonHeadLastFN
__all__ = ["OptimizerProcess", "ParameterProcess"] __all__ = ["OptimizerProcess", "ParameterProcess", "get_local_pca_mat_path", "load_local_pca_mat"]
class OptimizerProcess: class OptimizerProcess:
@ -265,3 +273,150 @@ class ParameterProcess:
else: else:
group_params.append({"params": params_value}) group_params.append({"params": params_value})
return group_params return group_params
def get_local_pca_mat_path(weight_load_dir, pca_mat_path, n_component, device_number):
"""
get local pca mat path.
Args:
weight_load_dir (str): The weight(ckpt) file directory to be load.
pca_mat_path (str): the path to load pca mat. Default: None.
n_component (int): pca component.
device_number (int): device number.
"""
if pca_mat_path is not None and os.path.exists(pca_mat_path) and os.path.isfile(pca_mat_path) and \
pca_mat_path.endswith(".npy"):
full_pca_mat_path = pca_mat_path
pca_mat_exist = True
else:
if weight_load_dir is None or not os.path.exists(weight_load_dir) or not os.path.isdir(weight_load_dir):
raise ValueError("The weight_load_dir: {} is None / not exists / not directory.".format(weight_load_dir))
full_pca_mat_path = os.path.join(weight_load_dir, "pca_mat_temp.npy")
pca_mat_exist = False
rank = _get_global_rank()
local_pca_mat_path = full_pca_mat_path[:-4] + "_rank_" + str(rank) + ".npy"
if os.path.exists(local_pca_mat_path):
os.remove(local_pca_mat_path)
if rank % device_number != 0:
return local_pca_mat_path
if pca_mat_exist:
pca_mat = np.load(full_pca_mat_path)
else:
data = _load_weights(weight_load_dir)
pca_mat = _compute_pca_mat(data, n_component)
np.save(full_pca_mat_path, pca_mat)
_save_local_pca_mat(pca_mat, full_pca_mat_path, n_component)
return local_pca_mat_path
def _load_weights(weight_load_dir):
"""
load weights.
Args:
weight_load_dir (str): The weight(ckpt) file directory to be load.
"""
param_mat = None
weight_file_list = os.listdir(weight_load_dir)
for file in weight_file_list:
if not file.endswith('.ckpt'):
continue
file_path = os.path.join(weight_load_dir, file)
param_dict = load_checkpoint(file_path)
param = None
for _, value in param_dict.items():
if param is None:
param = value.asnumpy().reshape((1, -1))
else:
param = np.hstack((param, value.asnumpy().reshape((1, -1))))
if param_mat is None:
param_mat = param
else:
param_mat = np.vstack((param_mat, param))
return param_mat
def _compute_pca_mat(data, n_component):
"""
compute pca mat.
Args:
data : array-like of shape (n_samples, n_features)
Training data, where `n_samples` is the number of samples
and `n_features` is the number of features.
n_component (int): pca component.
"""
if data.shape[0] < n_component:
raise ValueError("The samples: {} is less than: n_component {}.".format(data.shape[0], n_component))
mean = np.mean(data, axis=0)
data -= mean
u, _, v = la.svd(data, full_matrices=False)
_, v = _svd_flip(u, v)
components = v[:n_component]
return components
def _svd_flip(u, v):
"""
svd flip.
Args:
u (ndarray): the output of `linalg.svd`.
v (ndarray): the output of `linalg.svd`.
"""
max_abs_cols = np.argmax(np.abs(u), axis=0)
signs = np.sign(u[max_abs_cols, range(u.shape[1])])
u *= signs
v *= signs[:, np.newaxis]
return u, v
def _save_local_pca_mat(pca_mat, full_pca_mat_path, n_component):
"""
save pca mat.
Args:
pca_mat (numpy.ndarray): pca mat to be saved.
full_pca_mat_path (str): the path of full pca mat.
n_component (int): pca component.
"""
rank_size = get_group_size()
local_dim = math.ceil(n_component / rank_size)
for rank_id in range(rank_size):
start_index = rank_id * local_dim
end_index = (rank_id + 1) * local_dim
pca_start_index = min(n_component, start_index)
pca_end_index = min(n_component, end_index)
p_local = np.zeros([local_dim + 1, pca_mat.shape[1]])
if pca_start_index != pca_end_index:
p_local[0: pca_end_index - pca_start_index, :] = pca_mat[pca_start_index: pca_end_index, :]
local_pca_mat_path = full_pca_mat_path[:-4] + "_rank_" + str(rank_id) + ".npy"
np.save(local_pca_mat_path, p_local)
def load_local_pca_mat(local_pca_mat_path, n_component):
"""
load pca mat.
Args:
local_pca_mat_path (str): local pca mat file path.
n_component (int): pca component.
"""
rank_size = get_group_size()
local_dim = math.ceil(n_component / rank_size)
while True:
if os.path.exists(local_pca_mat_path):
break
time.sleep(5)
while True:
pca_mat = np.load(local_pca_mat_path)
if pca_mat.shape[0] == local_dim + 1:
break
time.sleep(5)
pca_mat = pca_mat[:-1, :]
return pca_mat

View File

@ -14,9 +14,11 @@
# ============================================================================ # ============================================================================
"""boost""" """boost"""
import threading import threading
from mindspore.nn.optim import SGD
from .less_batch_normalization import LessBN from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze from .grad_freeze import GradientFreeze
from .base import OptimizerProcess, ParameterProcess from .base import OptimizerProcess, ParameterProcess
from .base import get_local_pca_mat_path
__all__ = ["AutoBoost"] __all__ = ["AutoBoost"]
@ -27,17 +29,20 @@ _boost_config_level = {
"less_bn": False, "less_bn": False,
"grad_freeze": False, "grad_freeze": False,
"adasum": False, "adasum": False,
"grad_accumulation": False}, "grad_accumulation": False,
"dim_reduce": False},
"O1": { "O1": {
"less_bn": True, "less_bn": True,
"grad_freeze": True, "grad_freeze": True,
"adasum": False, "adasum": False,
"grad_accumulation": False}, "grad_accumulation": False,
"dim_reduce": False},
"O2": { "O2": {
"less_bn": True, "less_bn": True,
"grad_freeze": True, "grad_freeze": True,
"adasum": True, "adasum": True,
"grad_accumulation": False}} "grad_accumulation": False,
"dim_reduce": False}}
class AutoBoost: class AutoBoost:
@ -57,10 +62,12 @@ class AutoBoost:
"less_bn": false, "less_bn": false,
"grad_freeze": false, "grad_freeze": false,
"adasum": false, "adasum": false,
"grad_accumulation": false "grad_accumulation": false,
"dim_reduce": false
}, },
"common": { "common": {
"gradient_split_groups": [50, 100] "gradient_split_groups": [50, 100],
"device_number": 8
}, },
"less_bn": { "less_bn": {
"fn_flag": true, "fn_flag": true,
@ -73,10 +80,20 @@ class AutoBoost:
"total_steps": 65536 "total_steps": 65536
}, },
"adasum": { "adasum": {
"device_number": 8
}, },
"grad_accumulation": { "grad_accumulation": {
"grad_accumulation_step": 1 "grad_accumulation_step": 1
},
"dim_reduce": {
"ls_weight_decay": 0.0001,
"rho": 0.55,
"gamma": 0.9,
"alpha": 0.001,
"sigma": 0.4,
"n_components": 32,
"pca_mat_path": None,
"weight_load_dir": None
} }
} }
@ -120,6 +137,15 @@ class AutoBoost:
self.gradient_groups = None self.gradient_groups = None
self.device_number = 8 self.device_number = 8
self.grad_accumulation_step = 1 self.grad_accumulation_step = 1
self.ls_weight_decay = 0.0001
self.rho = 0.55
self.gamma = 0.9
self.alpha = 0.001
self.sigma = 0.4
self.n_components = 32
self.pca_mat_path = None
self.weight_load_dir = None
self.local_pca_mat_path = None
self.boost_config = self._get_configuration(level, self.boost_config_dict) self.boost_config = self._get_configuration(level, self.boost_config_dict)
self._param_processer = ParameterProcess() self._param_processer = ParameterProcess()
@ -141,6 +167,13 @@ class AutoBoost:
network (Cell): The training network. network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights. optimizer (Cell): Optimizer for updating the weights.
""" """
if self.boost_config["dim_reduce"]:
self.local_pca_mat_path = get_local_pca_mat_path(self.weight_load_dir, self.pca_mat_path,
self.n_components, self.device_number)
optimizer = SGD(network.trainable_params(), learning_rate=1, loss_scale=optimizer.loss_scale)
setattr(optimizer, "dim_reduce", True)
return network, optimizer
if self.boost_config["less_bn"]: if self.boost_config["less_bn"]:
network = LessBN(network, fn_flag=self._fn_flag) network = LessBN(network, fn_flag=self._fn_flag)
optimizer_process = OptimizerProcess(optimizer) optimizer_process = OptimizerProcess(optimizer)
@ -168,6 +201,8 @@ class AutoBoost:
Args: Args:
network (Cell): The inference network. network (Cell): The inference network.
""" """
if self.boost_config["dim_reduce"]:
return network
if self.boost_config["less_bn"]: if self.boost_config["less_bn"]:
network = LessBN(network) network = LessBN(network)
@ -204,6 +239,30 @@ class AutoBoost:
gradient_groups = list(gradient_groups) gradient_groups = list(gradient_groups)
self.gradient_groups = gradient_groups self.gradient_groups = gradient_groups
def set_ls_weight_decay(self, ls_weight_decay):
self.ls_weight_decay = ls_weight_decay
def set_rho(self, rho):
self.rho = rho
def set_gamma(self, gamma):
self.gamma = gamma
def set_alpha(self, alpha):
self.alpha = alpha
def set_sigma(self, sigma):
self.sigma = sigma
def set_n_components(self, n_components):
self.n_components = n_components
def set_pca_mat_path(self, pca_mat_path):
self.pca_mat_path = pca_mat_path
def set_weight_load_dir(self, weight_load_dir):
self.weight_load_dir = weight_load_dir
def _get_configuration(self, level, boost_config_dict): def _get_configuration(self, level, boost_config_dict):
"""Get configuration.""" """Get configuration."""
level_config = _boost_config_level[level] level_config = _boost_config_level[level]
@ -229,7 +288,6 @@ class AutoBoost:
self._boost_config_func_map[key_s](self, boost_each_mode_config[key_s]) self._boost_config_func_map[key_s](self, boost_each_mode_config[key_s])
return level_config return level_config
_boost_config_func_map = { _boost_config_func_map = {
"fn_flag": set_fn_flag, "fn_flag": set_fn_flag,
"gc_flag": set_gc_flag, "gc_flag": set_gc_flag,
@ -239,5 +297,13 @@ class AutoBoost:
"total_steps": set_total_steps, "total_steps": set_total_steps,
"device_number": set_device_number, "device_number": set_device_number,
"gradient_split_groups": set_gradient_split_groups, "gradient_split_groups": set_gradient_split_groups,
"grad_accumulation_step": set_grad_accumulation_step "grad_accumulation_step": set_grad_accumulation_step,
"ls_weight_decay": set_ls_weight_decay,
"rho": set_rho,
"gamma": set_gamma,
"alpha": set_alpha,
"sigma": set_sigma,
"n_components": set_n_components,
"pca_mat_path": set_pca_mat_path,
"weight_load_dir": set_weight_load_dir,
} }

View File

@ -19,7 +19,7 @@ from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_gradients_mean from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_gradients_mean
from mindspore.communication.management import get_group_size, create_group from mindspore.communication.management import get_group_size, create_group
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.common import Tensor from mindspore.common import Tensor, RowTensor
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import functional as F from mindspore.ops import functional as F
@ -29,7 +29,9 @@ from mindspore.common import dtype as mstype
from .boost import AutoBoost from .boost import AutoBoost
from .grad_freeze import FreezeOpt, freeze_cell from .grad_freeze import FreezeOpt, freeze_cell
from .adasum import AdaSum from .adasum import AdaSum
from .dim_reduce import DimReduce
from .grad_accumulation import gradient_accumulation_op, gradient_clear_op from .grad_accumulation import gradient_accumulation_op, gradient_clear_op
from .base import load_local_pca_mat
__all__ = ["BoostTrainOneStepCell", "BoostTrainOneStepWithLossScaleCell"] __all__ = ["BoostTrainOneStepCell", "BoostTrainOneStepWithLossScaleCell"]
@ -156,6 +158,23 @@ class BoostTrainOneStepCell(TrainOneStepCell):
if self.use_grad_accumulation: if self.use_grad_accumulation:
self.grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros') self.grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
self.enable_dim_reduce = self.check_dim_reduce_enable()
if self.enable_dim_reduce:
local_pca_mat_path = auto_boost.local_pca_mat_path
rho = auto_boost.rho
ls_weight_decay = auto_boost.ls_weight_decay
gamma = auto_boost.gamma
alpha = auto_boost.alpha
sigma = auto_boost.sigma
_rank = _get_global_rank()
_rank_size = get_group_size()
_device_number = auto_boost.device_number
n_components = auto_boost.n_components
pca_mat = load_local_pca_mat(local_pca_mat_path, n_components)
self.weights_clone = ParameterTuple(self.weights).clone(prefix="weights_clone", init="same")
self.dim_reduce = DimReduce(self.network, self.optimizer, self.weights, pca_mat, n_components, rho,
ls_weight_decay, gamma, alpha, sigma, _rank, _rank_size)
self.freeze_nets = None self.freeze_nets = None
self.step = Parameter(Tensor(0, dtype=mstype.int32)) self.step = Parameter(Tensor(0, dtype=mstype.int32))
if self.freeze: if self.freeze:
@ -169,7 +188,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
self.grad, self.use_grad_accumulation, self.mean, self.degree, self.grad, self.use_grad_accumulation, self.mean, self.degree,
self.max_accumulation_step) self.max_accumulation_step)
self.enable_adasum = self.check_adasum_enable(optimizer, self.reducer_flag) self.enable_adasum = self.check_adasum_enable()
self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32)) self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32))
if self.enable_adasum: if self.enable_adasum:
_rank = _get_global_rank() _rank = _get_global_rank()
@ -204,9 +223,11 @@ class BoostTrainOneStepCell(TrainOneStepCell):
grads = self.grad(self.network, self.weights)(*inputs, sens) grads = self.grad(self.network, self.weights)(*inputs, sens)
grads = self.grad_reducer(grads) grads = self.grad_reducer(grads)
if self.use_grad_accumulation: if self.use_grad_accumulation:
loss = self.gradient_accumulation_process(loss, grads) loss = self.gradient_accumulation_process(loss, grads, *inputs)
else: else:
if self.enable_adasum: if self.enable_dim_reduce:
loss = F.depend(loss, self.dim_reduce(loss, grads, self.weights, self.weights_clone, *inputs))
elif self.enable_adasum:
loss = F.depend(loss, self.adasum_process(loss, grads)) loss = F.depend(loss, self.adasum_process(loss, grads))
else: else:
loss = F.depend(loss, self.optimizer(grads)) loss = F.depend(loss, self.optimizer(grads))
@ -235,7 +256,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
self.step += 1 self.step += 1
return loss return loss
def gradient_accumulation_process(self, loss, grads): def gradient_accumulation_process(self, loss, grads, *inputs):
r""" r"""
Gradient accumulation algorithm process. Gradient accumulation algorithm process.
@ -251,7 +272,10 @@ class BoostTrainOneStepCell(TrainOneStepCell):
self.accumulation_step += 1 self.accumulation_step += 1
if self.accumulation_step >= self.max_accumulation_step: if self.accumulation_step >= self.max_accumulation_step:
if self.enable_adasum: if self.enable_dim_reduce:
loss = F.depend(loss, self.dim_reduce(loss, self.grad_accumulation, self.weights, self.weights_clone,
*inputs))
elif self.enable_adasum:
loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation)) loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation))
else: else:
loss = F.depend(loss, self.optimizer(self.grad_accumulation)) loss = F.depend(loss, self.optimizer(self.grad_accumulation))
@ -290,15 +314,11 @@ class BoostTrainOneStepCell(TrainOneStepCell):
self.hyper_map(F.partial(_save_weight), weight_tuple, update_weights) self.hyper_map(F.partial(_save_weight), weight_tuple, update_weights)
return loss return loss
def check_adasum_enable(self, optimizer, reducer_flag): def check_adasum_enable(self):
r""" r"""
Check adasum enable. Check adasum enable.
Args:
optimizer (Union[Cell]): Optimizer for updating the weights.
reducer_flag (bool): Reducer flag.
""" """
if not getattr(optimizer, "adasum", None) or not reducer_flag: if not getattr(self.optimizer, "adasum", None) or not self.reducer_flag:
return False return False
_rank_size = get_group_size() _rank_size = get_group_size()
_device_number = 8 _device_number = 8
@ -306,6 +326,14 @@ class BoostTrainOneStepCell(TrainOneStepCell):
is_enable = bool(group_number > 1 and group_number & (group_number - 1) == 0) is_enable = bool(group_number > 1 and group_number & (group_number - 1) == 0)
return is_enable return is_enable
def check_dim_reduce_enable(self):
r"""
Check dim_reduce enable.
"""
if not getattr(self.optimizer, "dim_reduce", None):
return False
return True
class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell): class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
r""" r"""
@ -421,10 +449,12 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
# if there is no overflow, do optimize # if there is no overflow, do optimize
if not overflow: if not overflow:
if self.use_grad_accumulation: if self.use_grad_accumulation:
loss = self.gradient_accumulation_process(loss, grads) loss = self.gradient_accumulation_process(loss, grads, *inputs)
else: else:
if self.enable_adasum: if self.enable_dim_reduce:
loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation)) loss = F.depend(loss, self.dim_reduce(loss, grads, self.weights, self.weights_clone, *inputs))
elif self.enable_adasum:
loss = F.depend(loss, self.adasum_process(loss, grads))
else: else:
loss = F.depend(loss, self.optimizer(grads)) loss = F.depend(loss, self.optimizer(grads))
return loss, cond, scaling_sens return loss, cond, scaling_sens

View File

@ -0,0 +1,268 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dim_reduce"""
import math
import numpy as np
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
__all__ = ["DimReduce"]
_save_weight = C.MultitypeFuncGraph("_save_weight")
@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(parameter, new_parameter):
return P.Assign()(parameter, new_parameter)
_pca_projection = C.MultitypeFuncGraph("_pca_projection")
@_pca_projection.register("Tensor", "Tensor")
def _pca_projection_process(pca_mat, grad):
grad_k = P.MatMul()(pca_mat, F.reshape(grad, (-1, 1)))
return grad_k
_pca_back_projection = C.MultitypeFuncGraph("_pca_back_projection")
@_pca_back_projection.register("Tensor", "Tensor", "Tensor")
def _pca_back_projection_process(grad_k, pca_mat, grad):
grad_proj = P.MatMul()(F.transpose(pca_mat, (1, 0)), grad_k)
grad_proj_reshape = F.reshape(grad_proj, F.shape(grad))
return grad_proj_reshape
_update_grad_res_momentum = C.MultitypeFuncGraph("_update_grad_res_momentum")
@_update_grad_res_momentum.register("Float32", "Float32", "Tensor", "Tensor", "Tensor")
def _update_grad_res_momentum_process(gamma, alpha, grad_res_momentum, grad, grad_proj):
grad_res_momentum_new = gamma * grad_res_momentum + grad - grad_proj
P.Assign()(grad_res_momentum, grad_res_momentum_new)
res = alpha * grad_res_momentum_new
return res
_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight")
@_get_delta_weight.register("Tensor", "Tensor", "Tensor")
def _get_delta_weight_process(rho, dn, grad_res_momentum):
delta_weight = grad_res_momentum - rho * dn
return delta_weight
class DimReduce(Cell):
r"""
The dimension reduce training, is a novel algorithm for accelerating convergence of Deep Learning models.
Args:
network (Cell): The training network. The network only supports single output.
optimizer (Union[Cell]): Optimizer for updating the weights.
weight (Tuple(Parameter)): Tuple of parameters.
pca_mat_local (numpy.ndarray): For PCA operation, k*n, k is part of n_components, n is the size of weight.
n_components (int): PCA.componets_.
rho (float): Apply to grad.
ls_weight_decay (float): Apply to l2loss.
gamma (float): Apply to grad.
alpha (float): Apply to grad.
sigma (float): Apply to loss.
rank (int): Rank number.
rank_size (int): Rank size.
Inputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **old_grad** (Tuple(Tensor)) - Tuple of gradient tensors.
- **weight** (Tuple(Tensor)) - Tuple of parameters.
- **weight_clone** (Tuple(Tensor)) - clone of weight
- **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
Outputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
"""
def __init__(self, network, optimizer, weight, pca_mat_local, n_components, rho, ls_weight_decay, gamma, alpha,
sigma, rank, rank_size):
super(DimReduce, self).__init__()
self.network = network
self.optimizer = optimizer
self.rank = rank
self.rank_size = rank_size
self.ls_weight_decay = ls_weight_decay
self.gamma = gamma
self.alpha = alpha
self.sigma = sigma
self.float_type = mstype.float32
self._set_rho_list(rho)
self._set_local_pca_mat(pca_mat_local, n_components, weight)
self._set_init_parameter(weight)
self.hyper_map = C.HyperMap()
self.allreduce = P.AllReduce()
self.allgather = P.AllGather()
self.concat = P.Concat()
self.matmul = P.MatMul()
self.mul = P.Mul()
self.l2loss = P.L2Loss()
self.add = P.Add()
def _set_rho_list(self, rho):
"""set rho list info."""
self.max_search_time = 3
self.rho_list = []
for i in range(self.max_search_time):
self.rho_list.append(Tensor(np.power(rho, i), dtype=self.float_type))
def _set_local_pca_mat(self, pca_mat_local, n_components, parameter_tuple):
"""set pca info."""
self.n_components = n_components
local_dim = math.ceil(self.n_components / self.rank_size)
self.start_index = self.rank * local_dim
self.end_index = (self.rank + 1) * local_dim
start = 0
self.pca_list_local = ()
for param in parameter_tuple:
size = np.shape(param.asnumpy().reshape((-1, 1)))[0]
self.pca_list_local += (Tensor(pca_mat_local[:, start:start + size], dtype=self.float_type),)
start += size
self.dk_pad_flag = False
pad_num = self.rank_size * local_dim - self.n_components
if pad_num:
self.dk_pad_flag = True
self.dk_pad_part = Tensor(np.zeros([pad_num, 1]), dtype=self.float_type)
self.broadcast_list = []
pca_rank_num = math.ceil(self.n_components / local_dim)
for i in range(pca_rank_num):
broadcast = P.Broadcast(i)
self.broadcast_list.append(broadcast)
def _set_init_parameter(self, parameter_tuple):
"""init parameters."""
self.true_flag = Tensor(True)
self.false_flag = Tensor(False)
self.epsilon = np.power(10.0, -20)
self.gk_last = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="gk_last")
self.gk_last_init = Parameter(Tensor(False), name="gk_last_init")
self.bk = Parameter(Tensor(np.eye(self.n_components), dtype=self.float_type), name="bk")
self.sk = Parameter(Tensor(np.zeros([self.n_components, 1]), dtype=self.float_type), name="sk")
self.eye = Tensor(np.eye(self.n_components), dtype=self.float_type)
self.grad_res_momentum = ParameterTuple(parameter_tuple).clone(prefix="grad_res_momentum", init="zeros")
def construct(self, loss, old_grad, weight, weight_clone, *inputs):
weight = F.depend(weight, loss)
l2_loss = self._get_l2_loss(weight)
old_loss = self.allreduce(loss) / self.rank_size + l2_loss
gk_local = self.hyper_map(_pca_projection, self.pca_list_local, old_grad)
gk_local = F.addn(gk_local)
gk_pad = self.allgather(gk_local)
gk_pad = F.reshape(gk_pad, (-1, 1))
gk = gk_pad[0:self.n_components, :]
dk = self._apply_quasi_newton_update(gk)
if self.dk_pad_flag:
dk_pad = self.concat((dk, self.dk_pad_part))
else:
dk_pad = dk
dk_local = dk_pad[self.start_index: self.end_index, :]
dn_local = self.hyper_map(F.partial(_pca_back_projection, dk_local), self.pca_list_local, old_grad)
grad_proj_local = self.hyper_map(F.partial(_pca_back_projection, gk_local), self.pca_list_local, old_grad)
dn = dn_local
grad_proj = grad_proj_local
for broadcast in self.broadcast_list:
dn_part = broadcast(dn_local)
dn = self.hyper_map(self.add, dn, dn_part)
grad_proj_part = broadcast(grad_proj_local)
grad_proj = self.hyper_map(self.add, grad_proj, grad_proj_part)
rho = self._line_search(gk, dk, dn, old_loss, weight, weight_clone, *inputs)
update_grad = self.hyper_map(F.partial(_update_grad_res_momentum, self.gamma, self.alpha),
self.grad_res_momentum, old_grad, grad_proj)
delta_weight = self.hyper_map(F.partial(_get_delta_weight, rho), dn, update_grad)
update = self.optimizer(delta_weight)
weight = F.depend(weight, update)
clone = self.hyper_map(_save_weight, weight_clone, weight)
loss = F.depend(loss, clone)
return loss
def _line_search(self, gk, dk, dn, old_loss, weight, weight_clone, *inputs):
"""line search rho."""
res = self.rho_list[self.max_search_time - 1]
for i in range(self.max_search_time):
find = self._find_rho(gk, dk, dn, old_loss, weight, weight_clone, self.rho_list[i], *inputs)
if find:
res = self.rho_list[i]
break
return res
def _find_rho(self, gk, dk, dn, old_loss, weight, weight_clone, rho, *inputs):
"""search rho."""
res = self.false_flag
sn = self.hyper_map(F.partial(self.mul, -1 * rho), dn)
sn = F.depend(sn, old_loss)
update = self.optimizer(sn)
new_loss = self.network(*inputs)
new_loss = self.allreduce(new_loss)
weight = F.depend(weight, update)
new_l2_loss = self._get_l2_loss(weight)
new_loss = new_loss / self.rank_size + new_l2_loss
old_loss_delta = old_loss + self.sigma * rho * F.squeeze(self.matmul(F.transpose(gk, (1, 0)), dk))
if old_loss_delta > new_loss:
_save_weight(self.sk, rho * dk)
res = self.true_flag
weight_clone = F.depend(weight_clone, old_loss_delta)
restore = self.hyper_map(_save_weight, weight, weight_clone)
res = F.depend(res, restore)
return res
def _apply_quasi_newton_update(self, gk):
"""apply quasi_newton update."""
if self.gk_last_init:
yk = gk - self.gk_last
g = self.matmul(F.transpose(yk, (1, 0)), self.sk)
g = F.squeeze(g)
if g > self.epsilon:
pk = 1. / g
t1 = self.eye - self.matmul(pk * yk, F.transpose(self.sk, (1, 0)))
new_bk = self.matmul(self.matmul(F.transpose(t1, (1, 0)), self.bk), t1) + \
self.matmul(pk * self.sk, F.transpose(self.sk, (1, 0)))
_save_weight(self.bk, new_bk)
else:
_save_weight(self.gk_last_init, self.true_flag)
_save_weight(self.gk_last, gk)
dk = -1 * self.matmul(self.bk, gk)
return dk
def _get_l2_loss(self, weight):
"""get l2 loss."""
l2_loss = self.hyper_map(self.l2loss, weight)
l2_loss = F.addn(l2_loss)
l2_loss *= self.ls_weight_decay
return l2_loss

View File

@ -60,7 +60,7 @@ class GradientAccumulation(Cell):
self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step") self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
def construct(self, loss, grads): def construct(self, loss, grads):
loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step), loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self._max_accumulation_step),
self._grad_accumulation, grads)) self._grad_accumulation, grads))
self._accumulation_step += 1 self._accumulation_step += 1
@ -69,6 +69,6 @@ class GradientAccumulation(Cell):
self._accumulation_step = 0 self._accumulation_step = 0
if self._accumulation_step == 0: if self._accumulation_step == 0:
loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation)) loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))
return loss return loss

View File

@ -277,7 +277,7 @@ def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulat
reducer_flag (bool): Reducer flag. reducer_flag (bool): Reducer flag.
network (Cell): The training network. network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights. optimizer (Cell): Optimizer for updating the weights.
sens (Tensor): Tensor with shape :math:`()` sens (numbers.Number): The scaling number.
grad (tuple(Tensor)): Tuple of gradient tensors. grad (tuple(Tensor)): Tuple of gradient tensors.
use_grad_accumulation (bool): Use gradient accumulation flag. use_grad_accumulation (bool): Use gradient accumulation flag.
mean (bool): Gradients mean flag. default: None. mean (bool): Gradients mean flag. default: None.

View File

@ -185,6 +185,7 @@ class SGD(Optimizer):
params = self.parameters params = self.parameters
accum = self.accum accum = self.accum
stat = self.stat stat = self.stat
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients) gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients) gradients = self.scale_grad(gradients)
lr = self.get_lr() lr = self.get_lr()

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Utils of auto parallel""" """Utils of auto parallel"""
import numpy as np import numpy as np
from mindspore import context, log as logger from mindspore import context, log as logger
from mindspore.context import ParallelMode from mindspore.context import ParallelMode

View File

@ -110,6 +110,7 @@ def _check_level(level, boost_level):
return level, enable_boost return level, enable_boost
def _add_loss_network(network, loss_fn, cast_model_type): def _add_loss_network(network, loss_fn, cast_model_type):
"""Add loss network.""" """Add loss network."""

View File

@ -193,8 +193,8 @@ class Model:
>>> model.train(2, dataset) >>> model.train(2, dataset)
""" """
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
eval_indexes=None, amp_level="O0", boost_level="O0", **kwargs): amp_level="O0", boost_level="O0", **kwargs):
self._network = network self._network = network
self._loss_fn = loss_fn self._loss_fn = loss_fn
self._optimizer = optimizer self._optimizer = optimizer