!24051 Pclint and docs fix.

Merge pull request !24051 from linqingke/bug_fix
i-robot 2021-09-26 06:28:07 +00:00 committed by Gitee
commit 11c3e1c193
3 changed files with 141 additions and 14 deletions

View File

@@ -19,13 +19,13 @@ accumulation and so on.
 Note:
     This feature is a beta feature, and we are still improving its functionality.
 """
-from .boost import *
-from .base import *
-from .boost_cell_wrapper import *
-from .less_batch_normalization import *
-from .grad_freeze import *
-from .grad_accumulation import *
-from .adasum import *
+from .boost import AutoBoost
+from .base import OptimizerProcess, ParameterProcess
+from .boost_cell_wrapper import BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell
+from .less_batch_normalization import LessBN
+from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell
+from .grad_accumulation import GradientAccumulation
+from .adasum import AdaSum
 
 __all__ = ['AutoBoost',
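
Replacing the wildcard imports with explicit ones makes the module's public surface checkable by static analysis, which matches the lint-fix intent of this commit. A minimal sketch of the same pattern (the package layout here is illustrative, not part of this commit):

# hypothetical mypkg/__init__.py following the same convention
from .boost import AutoBoost        # explicit import: linters can resolve the symbol
from .adasum import AdaSum

__all__ = ['AutoBoost', 'AdaSum']   # keep __all__ in sync with the explicit imports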

View File

@@ -51,6 +51,36 @@ def _save_weight_process(new_parameter, old_parameter):
     return P.Assign()(new_parameter, old_parameter)
 
+
+_grad_scale = C.MultitypeFuncGraph("grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))
+
+
+@_grad_scale.register("Tensor", "RowTensor")
+def tensor_grad_scale_row_tensor(scale, grad):
+    return RowTensor(grad.indices,
+                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
+                     grad.dense_shape)
+
+
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
+
+
+@_grad_overflow.register("RowTensor")
+def _tensor_grad_overflow_row_tensor(grad):
+    return grad_overflow(grad.values)
+
+
 class BoostTrainOneStepCell(TrainOneStepCell):
     r"""
     Boost Network training package class.
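
For context, `_grad_scale` and `_grad_overflow` above use `MultitypeFuncGraph`, which dispatches to a registered overload based on the input types, and is applied across the gradient tuple with `HyperMap`. A runnable sketch of that dispatch pattern (toy values, PyNative mode assumed; not part of this commit):

import mindspore as ms
from mindspore import Tensor
from mindspore.ops import composite as C, functional as F, operations as P

scale_grad = C.MultitypeFuncGraph("scale_grad")
reciprocal = P.Reciprocal()

@scale_grad.register("Tensor", "Tensor")
def _scale_dense(scale, grad):
    # dense overload: divide each gradient by the loss scale
    return grad * F.cast(reciprocal(scale), F.dtype(grad))

hyper_map = C.HyperMap()
grads = (Tensor(2.0, ms.float32), Tensor(4.0, ms.float32))
scaled = hyper_map(F.partial(scale_grad, Tensor(8.0, ms.float32)), grads)
# scaled is (0.25, 0.5): each gradient divided by the scale of 8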
@@ -77,6 +107,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
+        >>> from mindspore import boost
         >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
@@ -280,7 +311,8 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         >>> from mindspore import Tensor, Parameter, nn
         >>> import mindspore.ops as ops
         >>> from mindspore.nn import WithLossCell
-        >>> from mindspore.common import dtype as mstype
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import boost
         >>>
         >>> class Net(nn.Cell):
         ...     def __init__(self, in_features, out_features):
@@ -301,8 +333,8 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         >>> net_with_loss = WithLossCell(net, loss)
         >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
         >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
-        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
-        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
+        >>> input = Tensor(np.ones([out_features, in_features]), mstype.float32)
+        >>> labels = Tensor(np.ones([out_features,]), mstype.float32)
         >>> output = train_network(input, labels)
         >>>
         >>> #2) when the type of scale_sense is Tensor:
@@ -342,15 +374,15 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         loss = self.network(*inputs)
         scaling_sens = self.scale_sense
 
-        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
+        status, scaling_sens = self._start_overflow_check(loss, scaling_sens)
 
         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
         grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
         # get the overflow buffer
-        cond = self.get_overflow_status(status, grads)
-        overflow = self.process_loss_scale(cond)
+        cond = self._get_overflow_status(status, grads)
+        overflow = self._process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
             if self.use_grad_accumulation:
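
The control flow above is the standard dynamic loss-scaling recipe: scale the loss up before backpropagation, unscale the gradients, and skip the parameter update when an overflow is detected. A backend-agnostic sketch of the same recipe in plain Python (all helper names hypothetical, scalar gradients for brevity):

import math

def train_step(loss_fn, grad_fn, apply_fn, state):
    # state: {'scale': float, 'good_steps': int, 'window': int}
    loss = loss_fn()
    scaled_grads = grad_fn(loss * state['scale'])        # backprop on the scaled loss
    grads = [g / state['scale'] for g in scaled_grads]   # unscale the gradients
    overflow = any(not math.isfinite(g) for g in grads)
    if overflow:
        state['scale'] /= 2                              # shrink the scale, drop the update
        state['good_steps'] = 0
    else:
        apply_fn(grads)                                  # no overflow: apply the optimizer
        state['good_steps'] += 1
        if state['good_steps'] >= state['window']:
            state['scale'] *= 2                          # grow after a stable window
            state['good_steps'] = 0
    return loss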
@@ -361,3 +393,98 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         else:
             loss = F.depend(loss, self.optimizer(grads))
         return loss, cond, scaling_sens
+
+    def _set_sense_scale(self, sens):
+        """
+        If the user has set the sens in the training process and wants to reassign the value,
+        this function can be called again to make the modification; sens must be of type Tensor.
+
+        Inputs:
+            - **sens** (Tensor) - The new sense whose shape and type are the same as the original `scale_sense`.
+        """
+        if self.scale_sense and isinstance(sens, Tensor):
+            self.scale_sense.set_data(sens)
+        else:
+            raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
+
+    def _start_overflow_check(self, pre_cond, compute_input):
+        """
+        Start floating-point overflow detection. Create and clear the overflow detection state.
+
+        Specify the arguments 'pre_cond' and 'compute_input' to make sure the overflow status is
+        cleared at the right time. Taking this situation as an example: we need to execute state
+        clearing after loss calculation and then detect overflow in the process of gradient
+        calculation. In this case, pre_cond should be the output of the loss function, and
+        compute_input should be the input of the gradients-computing function.
+
+        Inputs:
+            - **pre_cond** (Tensor) - A precondition for starting overflow detection. It determines
+              the executing order of the overflow state clearing and the prior operations. It makes
+              sure that the function '_start_overflow_check' clears the status after finishing the
+              process of the precondition.
+            - **compute_input** (object) - The input of the subsequent process. Overflow detection
+              should be performed on a certain computation. Set `compute_input` as the input of the
+              computation, to ensure the overflow status is cleared before executing the computation.
+
+        Outputs:
+            Tuple[object, object], the first value is False for the GPU backend, while it is an
+            instance of NPUAllocFloatStatus for other backends. The status is used to detect overflow
+            during overflow detection. The second value is the same as the input of `compute_input`,
+            but contains some information about the execution order.
+        """
+        status = False
+        if not self.gpu_target:
+            # init overflow buffer
+            status = P.NPUAllocFloatStatus()()
+            status = F.depend(status, pre_cond)
+            # clear overflow buffer
+            clear_status = P.NPUClearFloatStatus()(status)
+            compute_input = F.depend(compute_input, clear_status)
+        return status, compute_input
+
+    def _get_overflow_status(self, status, compute_output):
+        """
+        Get floating-point overflow status.
+
+        Get overflow results after executing the target process for overflow detection.
+
+        Inputs:
+            - **status** (object) - A status instance used to detect the overflow.
+            - **compute_output** - Overflow detection should be performed on a certain computation.
+              Set `compute_output` as the output of the computation, to ensure the overflow status is
+              acquired only after the computation has executed.
+
+        Outputs:
+            bool, whether the overflow occurs or not.
+        """
+        if not self.gpu_target:
+            status = F.depend(status, compute_output)
+            get_status = P.NPUGetFloatStatus()(status)
+            status = F.depend(status, get_status)
+            # sum overflow buffer elements; 0: no overflow, >0: overflow
+            flag_sum = self.reduce_sum(status, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
+            flag_sum = P.AddN()(flag_sum)
+            # convert flag_sum to scalar
+            flag_sum = P.Reshape()(flag_sum, (()))
+
+        if self.is_distributed:
+            # sum overflow flag over devices
+            flag_reduce = self.allreduce(flag_sum)
+            overflow = self.less_equal(self.base, flag_reduce)
+        else:
+            overflow = self.less_equal(self.base, flag_sum)
+        return overflow
+
+    def _process_loss_scale(self, overflow):
+        """
+        Calculate the loss scale according to the overflow.
+
+        Inputs:
+            - **overflow** (bool) - Whether the overflow occurs or not.
+
+        Outputs:
+            bool, overflow value.
+        """
+        if self.loss_scaling_manager is not None:
+            return self.loss_scaling_manager(self.scale_sense, overflow)
+        return overflow
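
Since a Tensor `scale_sense` is stored internally as a Parameter, the new `_set_sense_scale` lets callers adjust the loss scale in place without rebuilding the cell. A hedged usage sketch, reusing the setup from the Examples above (the initial Tensor value is assumed):

from mindspore import Tensor, dtype as mstype

train_network = boost.BoostTrainOneStepWithLossScaleCell(
    net_with_loss, optimizer, scale_sense=Tensor(2 ** 12, mstype.float32))
# later: lower the scale in place; a non-Tensor argument raises TypeError
train_network._set_sense_scale(Tensor(2 ** 10, mstype.float32))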

View File

@@ -23,7 +23,7 @@ namespace mindspore {
 namespace opt {
 namespace irpass {
 namespace {
-enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
+enum class RemoveNodeType { kOtherNode = 0, kOptimizerNode };
 const char kLessBatchNormalizationPassName[] = "less_bn";
 constexpr auto kValidResidualStructureIndex = 1;
 constexpr auto kBNParametersStartIndex = 2;