Fix Boost docs bug.

This commit is contained in:
linqingke 2021-10-27 17:45:47 +08:00
parent 360e228f45
commit b5d54b161c
6 changed files with 144 additions and 25 deletions

View File

@ -150,9 +150,10 @@ class AdaSum(Cell):
parallel training of Deep Learning models.
Args:
network (Cell): The training network. The network only supports single output.
optimizer (Union[Cell]): Optimizer for updating the weights.
sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
rank (int): Rank number.
device_number (int): Device number.
group_number (int): Group number.
parameter_tuple (Tuple(Parameter)): Tuple of parameters.
Inputs:
- **delta_weights** (Tuple(Tensor)) - Tuple of gradients.

View File

@ -25,7 +25,7 @@ __all__ = ["OptimizerProcess", "ParameterProcess"]
class OptimizerProcess:
"""
r"""
Process optimizer for Boost. Currently, this class supports adding GC(grad centralization) tags
and creating new optimizers.
@ -68,7 +68,12 @@ class OptimizerProcess:
self.origin_params = opt.init_params["params"]
def build_params_dict(self, network):
"""Build the params dict of the network"""
r"""
Build the parameter's dict of the network.
Inputs:
- **network** (Cell) - The training network.
"""
cells = network.cells_and_names()
params_dict = {}
for _, cell in cells:
@ -77,7 +82,13 @@ class OptimizerProcess:
return params_dict
def build_gc_params_group(self, params_dict, parameters):
"""Build the params group that needs gc"""
r"""
Build the parameter's group with grad centralization.
Inputs:
- **params_dict** (dict) - The network's parameter dict.
- **parameters** (list) - The network's parameter list.
"""
group_params = []
for group_param in parameters:
if 'order_params' in group_param.keys():
@ -107,7 +118,12 @@ class OptimizerProcess:
return group_params
def add_grad_centralization(self, network):
"""Add gradient centralization."""
r"""
Add gradient centralization.
Inputs:
- **network** (Cell) - The training network.
"""
params_dict = self.build_params_dict(network)
parameters = self.origin_params
@ -137,7 +153,7 @@ class OptimizerProcess:
class ParameterProcess:
"""
r"""
Process parameter for Boost. Currently, this class supports creating group parameters
and automatically setting gradient segmentation point.
@ -171,7 +187,13 @@ class ParameterProcess:
self._parameter_indices = 1
def assign_parameter_group(self, parameters, split_point=None):
"""Assign parameter group."""
r"""
Assign parameter group.
Inputs:
- **parameters** (list) - The network's parameter list.
- **split_point** (list) - The gradient split point of this network. default: None.
"""
if not isinstance(parameters, (list, tuple)) or not parameters:
return parameters
@ -187,7 +209,13 @@ class ParameterProcess:
return parameters
def generate_group_params(self, parameters, origin_params):
"""Generate group parameters."""
r"""
Generate group parameters.
Inputs:
- **parameters** (list) - The network's parameter list.
- **origin_params** (list) - The network's origin parameter list.
"""
origin_params_copy = origin_params
if origin_params_copy is not None:
if not isinstance(origin_params_copy, list):

View File

@ -37,7 +37,7 @@ _boost_config_level = {
class AutoBoost:
"""
r"""
Provide auto accelerating for network.
Args:
@ -68,7 +68,13 @@ class AutoBoost:
self._boost_config_func_map[key](self, val)
def network_auto_process_train(self, network, optimizer):
"""Network train."""
r"""
Boost network train.
Inputs:
- **network** (Cell) - The training network.
- **optimizer** (Cell) - Optimizer for updating the weights.
"""
if self._boost_config["less_bn"]:
network = LessBN(network, fn_flag=self._fn_flag)
optimizer_process = OptimizerProcess(optimizer)
@ -90,7 +96,12 @@ class AutoBoost:
return network, optimizer
def network_auto_process_eval(self, network):
"""Network eval."""
r"""
Boost network eval.
Args:
- **network** (Cell) - The inference network.
"""
if self._boost_config["less_bn"]:
network = LessBN(network)

View File

@ -209,7 +209,12 @@ class BoostTrainOneStepCell(TrainOneStepCell):
return loss
def gradient_freeze_process(self, *inputs):
"""gradient freeze algorithm process."""
r"""
Gradient freeze algorithm process.
Inputs:
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
"""
if self.train_strategy is None:
step = self.step
max_index = len(self.freeze_nets)
@ -224,7 +229,13 @@ class BoostTrainOneStepCell(TrainOneStepCell):
return loss
def gradient_accumulation_process(self, loss, grads):
"""gradient accumulation algorithm process."""
r"""
Gradient accumulation algorithm process.
Inputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **grads** (Tuple(Tensor)) - Tuple of gradient tensors.
"""
loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self.max_accumulation_step),
self.grad_accumulation, grads))
self.accumulation_step += 1
@ -242,7 +253,13 @@ class BoostTrainOneStepCell(TrainOneStepCell):
return loss
def adasum_process(self, loss, grads):
"""adasum algorithm process."""
r"""
Adasum algorithm process.
Inputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **grads** (Tuple(Tensor)) - Tuple of gradient tensors.
"""
loss = F.depend(loss, self.optimizer(grads))
rank_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
grad_clone = F.depend(self.grad_clone, loss)
@ -261,7 +278,13 @@ class BoostTrainOneStepCell(TrainOneStepCell):
return loss
def check_adasum_enable(self, optimizer, reducer_flag):
"""check adasum enable."""
r"""
Check adasum enable.
Inputs:
- **optimizer** (Union[Cell]) - Optimizer for updating the weights.
- **reducer_flag** (bool) - Reducer flag.
"""
if not getattr(optimizer, "adasum", None) or not reducer_flag:
return False
_rank_size = get_group_size()
@ -280,7 +303,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
BoostTrainOneStepWithLossScaleCell will be compiled to be graph which takes `*inputs` as input data.
The Tensor type of `scale_sense` is acting as loss scaling value. If you want to update it on host side,
the value must be provided. If the Tensor type of `scale_sense` is not given, the loss scale update logic
must be provied by Cell type of `scale_sense`.
must be provide by Cell type of `scale_sense`.
Args:
network (Cell): The training network. The network only supports single output.

View File

@ -114,7 +114,7 @@ class FreezeOpt(Cell):
class _TrainFreezeCell(Cell):
"""
r"""
Gradient freezing training network.
Args:
@ -157,7 +157,7 @@ class _TrainFreezeCell(Cell):
class GradientFreeze:
"""
r"""
Freezing the gradients of some layers randomly. The number and
probability of frozen layers can be configured by users
@ -180,7 +180,13 @@ class GradientFreeze:
self._param_processer = ParameterProcess()
def split_parameters_groups(self, net, freeze_para_groups_number):
"""Split parameter groups for gradients freezing training."""
r"""
Split parameter groups for gradients freezing training.
Inputs:
- **net** (Cell) - The training network.
- **freeze_para_groups_number** (int) - The number of gradient freeze groups.
"""
grouped_params = []
tmp = []
for para in net.trainable_params():
@ -201,7 +207,15 @@ class GradientFreeze:
return freeze_grouped_params
def generate_freeze_index_sequence(self, parameter_groups_number, freeze_strategy, freeze_p, total_steps):
"""Generate index sequence for gradient freezing training."""
r"""
Generate index sequence for gradient freezing training.
Inputs:
- **parameter_groups_number** (int) - The number of parameter groups.
- **freeze_strategy** (int) - Gradient freeze grouping strategy, select from [0, 1].
- **freeze_p** (float) - Gradient freezing probability.
- **total_steps** (int) - Total training steps.
"""
total_step = int(total_steps * 1.01)
if parameter_groups_number <= 1:
return [0 for _ in range(total_step)]
@ -235,7 +249,13 @@ class GradientFreeze:
f"Unsupported freezing training strategy '{freeze_strategy}'")
def freeze_generate(self, network, optimizer):
"""Generate freeze network and optimizer."""
r"""
Generate freeze network and optimizer.
Inputs:
- **network** (Cell) - The training network.
- **optimizer** (Cell) - Optimizer for updating the weights.
"""
train_para_groups = self.split_parameters_groups(
network, self._param_groups)
for i in range(self._param_groups):
@ -250,7 +270,43 @@ class GradientFreeze:
def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
max_accumulation_step=1):
"""Provide freeze network cell."""
r"""
Generate freeze network and optimizer.
Inputs:
- **reducer_flag** (bool) - Reducer flag.
- **network** (Cell) - The training network.
- **optimizer** (Cell) - Optimizer for updating the weights.
- **sens** (Tensor) - Tensor with shape :math:`()`
- **grad** (Tuple(Tensor)) - Tuple of gradient tensors.
- **use_grad_accumulation** (bool) - Use gradient accumulation flag.
- **mean** (bool) - Gradients mean flag. default: None.
- **degree** (int) - Device number. default: None.
- **max_accumulation_step** (int) - Max accumulation steps. default: 1.
Examples:
>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> import mindspore.ops as ops
>>> from mindspore.boost.grad_freeze import freeze_cell
>>>
>>> class Net(nn.Cell):
... def __init__(self, in_features, out_features):
... super(Net, self).__init__()
... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
... name='weight')
... self.matmul = ops.MatMul()
...
... def construct(self, x):
... output = self.matmul(x, self.weight)
... return output
...
>>> in_features, out_features = 16, 10
>>> network = Net(in_features, out_features)
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> grad = ops.GradOperation(get_by_list=True, sens_param=True)
>>> freeze_nets = freeze_cell(False, network, optimizer, 1.0, grad, False, None, None, 1)
"""
if reducer_flag:
param_processer = ParameterProcess()
grad_reducers = (DistributedGradReducer(param_processer.assign_parameter_group(opt.parameters),

View File

@ -226,7 +226,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
TrainOneStepWithLossScaleCell will be compiled to be graph which takes `*inputs` as input data.
The Tensor type of `scale_sense` is acting as loss scaling value. If you want to update it on host side,
the value must be provided. If the Tensor type of `scale_sense` is not given, the loss scale update logic
must be provied by Cell type of `scale_sense`.
must be provide by Cell type of `scale_sense`.
Args:
network (Cell): The training network. The network only supports single output.