Fix Boost docs bug.

linqingke 2021-10-27 17:45:47 +08:00
parent 360e228f45
commit b5d54b161c
6 changed files with 144 additions and 25 deletions

View File

@@ -150,9 +150,10 @@ class AdaSum(Cell):
     parallel training of Deep Learning models.

     Args:
-        network (Cell): The training network. The network only supports single output.
-        optimizer (Union[Cell]): Optimizer for updating the weights.
-        sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
+        rank (int): Rank number.
+        device_number (int): Device number.
+        group_number (int): Group number.
+        parameter_tuple (Tuple(Parameter)): Tuple of parameters.

     Inputs:
         - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.

View File

@@ -25,7 +25,7 @@ __all__ = ["OptimizerProcess", "ParameterProcess"]

 class OptimizerProcess:
-    """
+    r"""
     Process optimizer for Boost. Currently, this class supports adding GC(grad centralization) tags
     and creating new optimizers.
@@ -68,7 +68,12 @@ class OptimizerProcess:
         self.origin_params = opt.init_params["params"]

     def build_params_dict(self, network):
-        """Build the params dict of the network"""
+        r"""
+        Build the parameter dict of the network.
+
+        Inputs:
+            - **network** (Cell) - The training network.
+        """
         cells = network.cells_and_names()
         params_dict = {}
         for _, cell in cells:
@@ -77,7 +82,13 @@ class OptimizerProcess:
         return params_dict

     def build_gc_params_group(self, params_dict, parameters):
-        """Build the params group that needs gc"""
+        r"""
+        Build the parameter group with grad centralization.
+
+        Inputs:
+            - **params_dict** (dict) - The network's parameter dict.
+            - **parameters** (list) - The network's parameter list.
+        """
         group_params = []
         for group_param in parameters:
             if 'order_params' in group_param.keys():
@@ -107,7 +118,12 @@ class OptimizerProcess:
         return group_params

     def add_grad_centralization(self, network):
-        """Add gradient centralization."""
+        r"""
+        Add gradient centralization.
+
+        Inputs:
+            - **network** (Cell) - The training network.
+        """
         params_dict = self.build_params_dict(network)
         parameters = self.origin_params
@@ -137,7 +153,7 @@

 class ParameterProcess:
-    """
+    r"""
     Process parameter for Boost. Currently, this class supports creating group parameters
     and automatically setting gradient segmentation point.
@@ -171,7 +187,13 @@ class ParameterProcess:
         self._parameter_indices = 1

     def assign_parameter_group(self, parameters, split_point=None):
-        """Assign parameter group."""
+        r"""
+        Assign parameter group.
+
+        Inputs:
+            - **parameters** (list) - The network's parameter list.
+            - **split_point** (list) - The gradient split point of this network. Default: None.
+        """
         if not isinstance(parameters, (list, tuple)) or not parameters:
             return parameters
@@ -187,7 +209,13 @@ class ParameterProcess:
         return parameters

     def generate_group_params(self, parameters, origin_params):
-        """Generate group parameters."""
+        r"""
+        Generate group parameters.
+
+        Inputs:
+            - **parameters** (list) - The network's parameter list.
+            - **origin_params** (list) - The network's origin parameter list.
+        """
         origin_params_copy = origin_params
         if origin_params_copy is not None:
             if not isinstance(origin_params_copy, list):
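
The two classes above are driven directly against a training network and its optimizer. A minimal usage sketch, assuming `OptimizerProcess` is exported from `mindspore.boost` and following the method signatures shown in this diff; the final `generate_new_optimizer()` call is not part of this change and is included only as the usual follow-up step:

>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> import mindspore.ops as ops
>>> from mindspore.boost import OptimizerProcess
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         return self.matmul(x, self.weight)
...
>>> network = Net(16, 10)
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> # Tag grad centralization onto the optimizer's parameter groups, then rebuild the optimizer.
>>> optimizer_process = OptimizerProcess(optimizer)
>>> optimizer_process.add_grad_centralization(network)
>>> optimizer = optimizer_process.generate_new_optimizer()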

View File

@@ -37,7 +37,7 @@ _boost_config_level = {

 class AutoBoost:
-    """
+    r"""
     Provide auto accelerating for network.

     Args:
@@ -68,7 +68,13 @@ class AutoBoost:
             self._boost_config_func_map[key](self, val)

     def network_auto_process_train(self, network, optimizer):
-        """Network train."""
+        r"""
+        Boost network train.
+
+        Inputs:
+            - **network** (Cell) - The training network.
+            - **optimizer** (Cell) - Optimizer for updating the weights.
+        """
         if self._boost_config["less_bn"]:
             network = LessBN(network, fn_flag=self._fn_flag)
         optimizer_process = OptimizerProcess(optimizer)
@@ -90,7 +96,12 @@ class AutoBoost:
         return network, optimizer

     def network_auto_process_eval(self, network):
-        """Network eval."""
+        r"""
+        Boost network eval.
+
+        Inputs:
+            - **network** (Cell) - The inference network.
+        """
         if self._boost_config["less_bn"]:
             network = LessBN(network)
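
These two hooks are the entry points AutoBoost applies to a model before training and evaluation. A minimal sketch of calling them directly, reusing `network` and `optimizer` from the sketch above; the `"O1"` level value passed to the constructor is an assumption and is not part of this diff:

>>> from mindspore.boost import AutoBoost
>>>
>>> auto_boost = AutoBoost("O1")
>>> # Rewrites the network (e.g. LessBN) and rebuilds the optimizer as configured.
>>> network, optimizer = auto_boost.network_auto_process_train(network, optimizer)
>>> # The eval network only receives the network-side rewrites.
>>> eval_network = auto_boost.network_auto_process_eval(network)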

View File

@@ -209,7 +209,12 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         return loss

     def gradient_freeze_process(self, *inputs):
-        """gradient freeze algorithm process."""
+        r"""
+        Gradient freeze algorithm process.
+
+        Inputs:
+            - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
+        """
         if self.train_strategy is None:
             step = self.step
             max_index = len(self.freeze_nets)
@@ -224,7 +229,13 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         return loss

     def gradient_accumulation_process(self, loss, grads):
-        """gradient accumulation algorithm process."""
+        r"""
+        Gradient accumulation algorithm process.
+
+        Inputs:
+            - **loss** (Tensor) - Tensor with shape :math:`()`.
+            - **grads** (Tuple(Tensor)) - Tuple of gradient tensors.
+        """
         loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self.max_accumulation_step),
                                              self.grad_accumulation, grads))
         self.accumulation_step += 1
@@ -242,7 +253,13 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         return loss

     def adasum_process(self, loss, grads):
-        """adasum algorithm process."""
+        r"""
+        Adasum algorithm process.
+
+        Inputs:
+            - **loss** (Tensor) - Tensor with shape :math:`()`.
+            - **grads** (Tuple(Tensor)) - Tuple of gradient tensors.
+        """
         loss = F.depend(loss, self.optimizer(grads))
         rank_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
         grad_clone = F.depend(self.grad_clone, loss)
@@ -261,7 +278,13 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         return loss

     def check_adasum_enable(self, optimizer, reducer_flag):
-        """check adasum enable."""
+        r"""
+        Check adasum enable.
+
+        Inputs:
+            - **optimizer** (Union[Cell]) - Optimizer for updating the weights.
+            - **reducer_flag** (bool) - Reducer flag.
+        """
         if not getattr(optimizer, "adasum", None) or not reducer_flag:
             return False
         _rank_size = get_group_size()
@@ -280,7 +303,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
     BoostTrainOneStepWithLossScaleCell will be compiled to be graph which takes `*inputs` as input data.
     The Tensor type of `scale_sense` is acting as loss scaling value. If you want to update it on host side,
     the value must be provided. If the Tensor type of `scale_sense` is not given, the loss scale update logic
-    must be provied by Cell type of `scale_sense`.
+    must be provided by Cell type of `scale_sense`.

     Args:
         network (Cell): The training network. The network only supports single output.
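
The methods above are dispatched internally once a model is wrapped in the boost one-step cell. A minimal end-to-end sketch of that wrapping, reusing the `Net` class from the first sketch and assuming `BoostTrainOneStepCell` is exported from `mindspore.boost`:

>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> from mindspore.boost import BoostTrainOneStepCell
>>>
>>> network = Net(16, 10)
>>> loss_net = nn.WithLossCell(network, nn.MSELoss())
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> # sens is the scaling value fed into backpropagation, as in TrainOneStepCell.
>>> train_cell = BoostTrainOneStepCell(loss_net, optimizer, sens=1.0)
>>> data = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 10]).astype(np.float32))
>>> loss = train_cell(data, label)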

View File

@@ -114,7 +114,7 @@ class FreezeOpt(Cell):

 class _TrainFreezeCell(Cell):
-    """
+    r"""
     Gradient freezing training network.

     Args:
@@ -157,7 +157,7 @@ class _TrainFreezeCell(Cell):

 class GradientFreeze:
-    """
+    r"""
     Freezing the gradients of some layers randomly. The number and
     probability of frozen layers can be configured by users
@@ -180,7 +180,13 @@
         self._param_processer = ParameterProcess()

     def split_parameters_groups(self, net, freeze_para_groups_number):
-        """Split parameter groups for gradients freezing training."""
+        r"""
+        Split parameter groups for gradients freezing training.
+
+        Inputs:
+            - **net** (Cell) - The training network.
+            - **freeze_para_groups_number** (int) - The number of gradient freeze groups.
+        """
         grouped_params = []
         tmp = []
         for para in net.trainable_params():
@@ -201,7 +207,15 @@
         return freeze_grouped_params

     def generate_freeze_index_sequence(self, parameter_groups_number, freeze_strategy, freeze_p, total_steps):
-        """Generate index sequence for gradient freezing training."""
+        r"""
+        Generate index sequence for gradient freezing training.
+
+        Inputs:
+            - **parameter_groups_number** (int) - The number of parameter groups.
+            - **freeze_strategy** (int) - Gradient freeze grouping strategy, select from [0, 1].
+            - **freeze_p** (float) - Gradient freezing probability.
+            - **total_steps** (int) - Total training steps.
+        """
         total_step = int(total_steps * 1.01)
         if parameter_groups_number <= 1:
             return [0 for _ in range(total_step)]
@@ -235,7 +249,13 @@
                 f"Unsupported freezing training strategy '{freeze_strategy}'")

     def freeze_generate(self, network, optimizer):
-        """Generate freeze network and optimizer."""
+        r"""
+        Generate freeze network and optimizer.
+
+        Inputs:
+            - **network** (Cell) - The training network.
+            - **optimizer** (Cell) - Optimizer for updating the weights.
+        """
         train_para_groups = self.split_parameters_groups(
             network, self._param_groups)
         for i in range(self._param_groups):
@@ -250,7 +270,43 @@

 def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
                 max_accumulation_step=1):
-    """Provide freeze network cell."""
+    r"""
+    Generate freeze network and optimizer.
+
+    Inputs:
+        - **reducer_flag** (bool) - Reducer flag.
+        - **network** (Cell) - The training network.
+        - **optimizer** (Cell) - Optimizer for updating the weights.
+        - **sens** (Tensor) - Tensor with shape :math:`()`.
+        - **grad** (Tuple(Tensor)) - Tuple of gradient tensors.
+        - **use_grad_accumulation** (bool) - Use gradient accumulation flag.
+        - **mean** (bool) - Gradients mean flag. Default: None.
+        - **degree** (int) - Device number. Default: None.
+        - **max_accumulation_step** (int) - Max accumulation steps. Default: 1.
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor, Parameter, nn
+        >>> import mindspore.ops as ops
+        >>> from mindspore.boost.grad_freeze import freeze_cell
+        >>>
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, in_features, out_features):
+        ...         super(Net, self).__init__()
+        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+        ...                                 name='weight')
+        ...         self.matmul = ops.MatMul()
+        ...
+        ...     def construct(self, x):
+        ...         output = self.matmul(x, self.weight)
+        ...         return output
+        ...
+        >>> in_features, out_features = 16, 10
+        >>> network = Net(in_features, out_features)
+        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> grad = ops.GradOperation(get_by_list=True, sens_param=True)
+        >>> freeze_nets = freeze_cell(False, network, optimizer, 1.0, grad, False, None, None, 1)
+    """
     if reducer_flag:
         param_processer = ParameterProcess()
         grad_reducers = (DistributedGradReducer(param_processer.assign_parameter_group(opt.parameters),
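
Beyond the `freeze_cell` example embedded in the docstring above, the `GradientFreeze` methods documented in this diff can be exercised on their own. A minimal sketch; the constructor arguments (group count, strategy, freezing probability, total steps) are assumptions not shown in this change, and full gradient-freeze training may additionally require an Ascend backend:

>>> from mindspore.boost.grad_freeze import GradientFreeze
>>>
>>> # 10 parameter groups, strategy 1, freezing probability 0.7, 65536 training steps (assumed arguments).
>>> gradient_freeze = GradientFreeze(10, 1, 0.7, 65536)
>>> # Entry i of the returned sequence tells which parameter group is frozen at step i.
>>> indexes = gradient_freeze.generate_freeze_index_sequence(10, 1, 0.7, 65536)

`freeze_generate(network, optimizer)` is the documented entry point that then splits the trainable parameters into groups and builds the per-group freeze training cells.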

View File

@@ -226,7 +226,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     TrainOneStepWithLossScaleCell will be compiled to be graph which takes `*inputs` as input data.
     The Tensor type of `scale_sense` is acting as loss scaling value. If you want to update it on host side,
     the value must be provided. If the Tensor type of `scale_sense` is not given, the loss scale update logic
-    must be provied by Cell type of `scale_sense`.
+    must be provided by Cell type of `scale_sense`.

     Args:
         network (Cell): The training network. The network only supports single output.
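
The `scale_sense` distinction described above (a fixed scalar Tensor versus a Cell carrying the update logic) looks like this in practice. A minimal sketch, reusing `loss_net` and `optimizer` from the earlier BoostTrainOneStepCell sketch:

>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> from mindspore import dtype as mstype
>>>
>>> # Fixed loss scale: pass a scalar Tensor; it can be updated on the host side if needed.
>>> scale_sense = Tensor(np.full((1,), 1024.0), dtype=mstype.float32)
>>> train_cell = nn.TrainOneStepWithLossScaleCell(loss_net, optimizer, scale_sense=scale_sense)
>>>
>>> # Dynamic loss scale: pass a Cell that provides the update logic.
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12, scale_factor=2, scale_window=1000)
>>> train_cell = nn.TrainOneStepWithLossScaleCell(loss_net, optimizer, scale_sense=manager)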