fix boost docs problems.

This commit is contained in:
linqingke 2021-09-16 16:56:41 +08:00
parent c8b9d45abc
commit b3b319d5ca
9 changed files with 56 additions and 34 deletions

View File

@@ -13,9 +13,11 @@
# limitations under the License.
# ============================================================================
"""
MindBoost(Beta Feature)
Boost provide auto accelerating for network, such as Less BN, Gradient Freeze, Gradient
accumulation and so on.
Provide auto accelerating for network, such as Less BN, Gradient Freeze.
Note:
This feature is a beta feature, and we are still improving its functionality.
"""
from .boost import *
from .base import *
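For context on what this module wraps, the docstring example further down in this commit builds a BoostTrainOneStepCell around a loss network. A minimal end-to-end sketch of that usage follows; the Net definition, shapes and optimizer settings are illustrative assumptions, not code from this commit.

import numpy as np
import mindspore.nn as nn
import mindspore.boost as boost
from mindspore import Tensor

class Net(nn.Cell):
    """Tiny illustrative network: a single dense layer."""
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(16, 4)

    def construct(self, x):
        return self.dense(x)

net = Net()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
loss_net = nn.WithLossCell(net, loss_fn)
# Same wrapping as the BoostTrainOneStepCell docstring example later in this commit.
train_net = boost.BoostTrainOneStepCell(loss_net, optim)
data = Tensor(np.random.rand(8, 16).astype(np.float32))
label = Tensor(np.random.randint(0, 4, (8,)).astype(np.int32))
loss = train_net(data, label)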

View File

@@ -108,7 +108,7 @@ def _adasum_opt_rollback_process(left_send, delta_w, send, recv):
class AdaSum(Cell):
"""
r"""
The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data
parallel training of Deep Learning models.
@@ -117,7 +117,13 @@ class AdaSum(Cell):
optimizer (Union[Cell]): Optimizer for updating the weights.
sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Return:
Inputs:
- **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
- **parameters** (Tuple(Parameter)) - Tuple of current parameters.
- **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.
Outputs:
- **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process.
"""
def __init__(self, rank, device_number, group_number, parameter_tuple):
super(AdaSum, self).__init__()
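For intuition about what the AdaSum cell computes, the pairwise adaptive-summation rule from the AdaSum paper can be sketched in plain NumPy. This is only an illustration of the combination formula, not the distributed implementation in this file.

import numpy as np

def adasum_combine(g1, g2):
    # Each gradient is scaled down by the component it shares with the other
    # before the two are added: near-parallel gradients end up averaged,
    # orthogonal gradients end up summed.
    dot = np.dot(g1, g2)
    return ((1 - dot / (2 * np.dot(g1, g1))) * g1
            + (1 - dot / (2 * np.dot(g2, g2))) * g2)

print(adasum_combine(np.array([1.0, 0.0]), np.array([1.0, 0.0])))  # parallel -> [1. 0.] (average)
print(adasum_combine(np.array([1.0, 0.0]), np.array([0.0, 1.0])))  # orthogonal -> [1. 1.] (sum)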

View File

@@ -34,7 +34,7 @@ class OptimizerProcess:
Examples:
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore.ops import operations as P
>>> import mindspore.ops as ops
>>> from mindspore.boost import OptimizerProcess
>>>
>>> class Net(nn.Cell):
@@ -142,7 +142,7 @@ class ParameterProcess:
Examples:
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore.ops import operations as P
>>> import mindspore.ops as ops
>>> from mindspore.boost import OptimizerProcess
>>>
>>> class Net(nn.Cell):

View File

@@ -41,7 +41,8 @@ class AutoBoost:
Provide auto accelerating for network.
Args:
level (Str): boost config level.
level (str): boost config level.
kwargs (any): Additional configuration parameters related to boost.
"""
def __init__(self, level, kwargs):
if level not in _boost_config_level.keys():
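Given the constructor shown above, a hedged instantiation sketch; it assumes AutoBoost is exported from mindspore.boost and that an empty dict is an acceptable configuration.

import mindspore.boost as boost

# "O1" selects the boost config level; the empty dict passes no extra
# configuration through the kwargs parameter of the constructor above.
auto_boost = boost.AutoBoost("O1", {})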

View File

@@ -100,7 +100,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
... return self._backbone
...
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = boost.BoostTrainOneStepCellTrainOneStepCell(loss_net, optim)
>>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)
"""
def __init__(self, network, optimizer, sens=1.0):
@@ -117,9 +117,9 @@ class BoostTrainOneStepCell(TrainOneStepCell):
if self.max_accumulation_step <= 1:
self.max_accumulation_step = 1
self.use_grad_accumulation = False
self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
self.accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
if self.use_grad_accumulation:
self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
self.grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
self.freeze_nets = None
self.step = Parameter(Tensor(0, dtype=mstype.int32))
@@ -134,13 +134,13 @@ class BoostTrainOneStepCell(TrainOneStepCell):
self.grad, self.use_grad_accumulation, self.mean, self.degree,
self.max_accumulation_step)
self.enable_adasum = self._check_adasum_enable(optimizer, self.reducer_flag)
self.enable_adasum = self.check_adasum_enable(optimizer, self.reducer_flag)
self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32))
if self.enable_adasum:
_rank = _get_global_rank()
_rank_size = get_group_size()
_device_number = 8
self._device_number = _device_number
self.device_number = _device_number
group_number = _rank_size // _device_number
self.server_rank = _rank % _device_number
@@ -195,19 +195,19 @@ class BoostTrainOneStepCell(TrainOneStepCell):
def gradient_accumulation_process(self, loss, grads):
"""gradient accumulation algorithm process."""
loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self._max_accumulation_step),
self._grad_accumulation, grads))
self._accumulation_step += 1
loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self.max_accumulation_step),
self.grad_accumulation, grads))
self.accumulation_step += 1
if self._accumulation_step >= self._max_accumulation_step:
if self.accumulation_step >= self.max_accumulation_step:
if self.enable_adasum:
loss = F.depend(loss, self.adasum_process(loss, self._grad_accumulation))
loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation))
else:
loss = F.depend(loss, self.optimizer(self._grad_accumulation))
self._accumulation_step = 0
loss = F.depend(loss, self.optimizer(self.grad_accumulation))
self.accumulation_step = 0
if self._accumulation_step == 0:
loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))
if self.accumulation_step == 0:
loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self.grad_accumulation))
return loss
@@ -220,7 +220,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
adasum_res = self.adasum(delta_w, rank_weights, grad_clone)
sync_tensor = F.depend(self.sync_tensor, adasum_res)
sync_flag = self.adasum.sync_barrier(sync_tensor)
for i in range(self._device_number):
for i in range(self.device_number):
weight_tuple = self.weights[self.start[i] : self.end[i]]
node_rank = F.depend(weight_tuple, sync_flag)
update_weights = self.adasum.broadcast_list[i](node_rank)
@@ -230,7 +230,8 @@ class BoostTrainOneStepCell(TrainOneStepCell):
self.hyper_map(F.partial(_save_weight), weight_tuple, update_weights)
return loss
def _check_adasum_enable(self, optimizer, reducer_flag):
def check_adasum_enable(self, optimizer, reducer_flag):
"""check adasum enable."""
if not getattr(optimizer, "adasum", None) or not reducer_flag:
return False
_rank_size = get_group_size()
@@ -277,8 +278,8 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
Examples:
>>> import numpy as np
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore.ops import operations as P
>>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
>>> import mindspore.ops as ops
>>> from mindspore.nn import WithLossCell
>>> from mindspore.common import dtype as mstype
>>>
>>> class Net(nn.Cell):
@@ -356,7 +357,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
loss = self.gradient_accumulation_process(loss, grads)
else:
if self.enable_adasum:
loss = F.depend(loss, self.adasum_process(loss, self._grad_accumulation))
loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation))
else:
loss = F.depend(loss, self.optimizer(grads))
return loss, cond, scaling_sens
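The gradient_accumulation_process method changed above follows the usual accumulate-then-apply control flow. A framework-agnostic sketch of that pattern, with hypothetical apply_fn/clear_fn callbacks standing in for the optimizer call and the buffer reset:

import numpy as np

def accumulate_then_apply(grad_buffer, grads, step, max_step, apply_fn, clear_fn):
    # Add the latest micro-batch gradients into the running buffer.
    for buf, g in zip(grad_buffer, grads):
        buf += g
    step += 1
    if step >= max_step:
        # After max_step micro-batches, apply the accumulated gradients once.
        apply_fn(grad_buffer)
        step = 0
    if step == 0:
        # Reset the buffer so the next accumulation window starts from zero.
        clear_fn(grad_buffer)
    return step

buffer, step = [np.zeros(2)], 0
for _ in range(3):
    step = accumulate_then_apply(buffer, [np.ones(2)], step, 3,
                                 apply_fn=lambda b: print("apply", b),
                                 clear_fn=lambda b: [x.fill(0.0) for x in b])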

View File

@@ -39,8 +39,8 @@ class FreezeOpt(Cell):
Optimizer that supports gradients freezing training.
Args:
opt (Optimizer): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
train_parameter_groups (Union[Tuple, List]): Groups of parameters for gradients freezing training.
opt (Cell): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
train_parameter_groups (Union[tuple, list]): Groups of parameters for gradients freezing training.
train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.
Supported Platforms:
@@ -120,10 +120,10 @@ class _TrainFreezeCell(Cell):
Args:
net (Cell): The training network.
sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
grad (tuple(tensor)): The gradients of network parameters and inputs.
grad (tuple(Tensor)): The gradients of network parameters and inputs.
grad_reducer (Cell): Constructs a gradient reducer Cell, which applies communication and average operations on
single-process gradient values.
use_grad_accumulation (Bool): Whether use grad accumulation.
use_grad_accumulation (bool): Whether use grad accumulation.
optimizer (Union[Cell]): Optimizer for updating the weights.
max_accumulation_step (numbers.Number): Max grad accumulation steps. Default: 1.0
@@ -162,9 +162,9 @@ class GradientFreeze:
probability of frozen layers can be configured by users
Args:
param_groups (Union[Tuple, List]): Groups of parameters for gradients freezing training.
freeze_type (Int): Strategy of gradients freezing training.
freeze_p (FLoat): probability of gradients freezing training.
param_groups (Union[tuple, list]): Groups of parameters for gradients freezing training.
freeze_type (int): Strategy of gradients freezing training.
freeze_p (float): probability of gradients freezing training.
total_steps (numbers.Number): Steps of the whole training.
Examples:
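A brief usage sketch for GradientFreeze, assuming its freeze_generate method returns the rewritten network and a freeze-aware optimizer; the single dense layer and the hyper-parameters are illustrative stand-ins for a real backbone and schedule.

import mindspore.nn as nn
import mindspore.boost as boost

net = nn.Dense(16, 4)  # stands in for a real multi-layer training network
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# 10 parameter groups, freeze strategy 1, freeze probability 0.7, 65536 total training steps.
gradient_freeze = boost.GradientFreeze(10, 1, 0.7, 65536)
net, optimizer = gradient_freeze.freeze_generate(net, optimizer)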

View File

@@ -89,6 +89,7 @@ class LessBN(Cell):
Args:
network (Cell): Network to be modified.
fn_flag (bool): Replace FC with FN. default: False.
Examples:
>>> network = boost.LessBN(network)

View File

@@ -447,7 +447,7 @@ bool EliminateLoadBeforeAssigns(const FuncGraphManagerPtr &manager, const CNodeP
// If Load used by other nodes, keep load node.
auto assign_cnode = assign->cast<CNodePtr>();
if (OnlyUsedByOneNode(ref_node, assign_cnode)) {
manager->Replace(ref_node, parameter);
(void)manager->Replace(ref_node, parameter);
} else {
manager->SetEdge(assign, kInputIndex, parameter);
}

View File

@@ -133,6 +133,17 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
O2 is recommended on GPU, O3 is recommended on Ascend. Property of `keep_batchnorm_fp32` , `cast_model_type`
and `loss_scale_manager` determined by `level` setting may be overwritten by settings in `kwargs` .
boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
training. Supports ["O0", "O1", "O2"]. Default: "O0".
- O0: Do not change.
- O1: Enable the boost mode, the performance is improved by about 20%, and
the accuracy is the same as the original accuracy.
- O2: Enable the boost mode, the performance is improved by about 30%, and
the accuracy is reduced by less than 3%.
If O1 or O2 mode is set, the boost related library will take effect automatically.
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set, the
network will be cast to `cast_model_type` ( `mstype.float16` or `mstype.float32` ) rather than to the type
determined by the `level` setting.
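Finally, a hedged sketch of enabling boost through this entry point, assuming build_train_network is the helper defined in mindspore.train.amp; the network, loss and optimizer are illustrative stand-ins.

import mindspore.nn as nn
from mindspore.train import amp

net = nn.Dense(16, 4)  # stands in for a real backbone
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# level controls mixed precision; boost_level="O1" turns on the boost mode described above.
train_network = amp.build_train_network(net, optim, loss_fn, level="O0", boost_level="O1")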