!24051 Pclint and docs fix.
Merge pull request !24051 from linqingke/bug_fix
commit 11c3e1c193
@@ -19,13 +19,13 @@ accumulation and so on.
 Note:
     This feature is a beta feature, and we are still improving its functionality.
 """
-from .boost import *
-from .base import *
-from .boost_cell_wrapper import *
-from .less_batch_normalization import *
-from .grad_freeze import *
-from .grad_accumulation import *
-from .adasum import *
+from .boost import AutoBoost
+from .base import OptimizerProcess, ParameterProcess
+from .boost_cell_wrapper import BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell
+from .less_batch_normalization import LessBN
+from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell
+from .grad_accumulation import GradientAccumulation
+from .adasum import AdaSum


 __all__ = ['AutoBoost',
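With the wildcard imports replaced by explicit ones, downstream code imports the boost symbols by name. A minimal usage sketch (not part of this patch; `Net` is a placeholder user-defined nn.Cell):

import mindspore.nn as nn
from mindspore.boost import BoostTrainOneStepCell

net = Net()  # hypothetical user-defined nn.Cell
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = nn.WithLossCell(net, loss_fn)
# Wrap the loss cell with the boost one-step training cell.
train_cell = BoostTrainOneStepCell(net_with_loss, optimizer)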
@@ -51,6 +51,36 @@ def _save_weight_process(new_parameter, old_parameter):
     return P.Assign()(new_parameter, old_parameter)


+_grad_scale = C.MultitypeFuncGraph("grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))
+
+
+@_grad_scale.register("Tensor", "RowTensor")
+def tensor_grad_scale_row_tensor(scale, grad):
+    return RowTensor(grad.indices,
+                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
+                     grad.dense_shape)
+
+
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
+
+
+@_grad_overflow.register("RowTensor")
+def _tensor_grad_overflow_row_tensor(grad):
+    return grad_overflow(grad.values)
+
+
 class BoostTrainOneStepCell(TrainOneStepCell):
     r"""
     Boost Network training package class.
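The `_grad_scale` and `_grad_overflow` graphs above use `MultitypeFuncGraph` to pick an implementation by operand type (dense `Tensor` vs. `RowTensor`) when `hyper_map` applies them over the gradient tuple. A standalone sketch of the same registration pattern, illustrative only and not quoted from the patch:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, dtype as mstype
from mindspore.ops import composite as C, functional as F, operations as P

_scale = C.MultitypeFuncGraph("scale")

@_scale.register("Tensor", "Tensor")
def _scale_tensor(scale, grad):
    # One overload per type signature; chosen when the graph is applied.
    return grad * F.cast(P.Reciprocal()(scale), F.dtype(grad))

class ScaleGrads(nn.Cell):
    """Divide every gradient in a tuple by the loss scale."""
    def __init__(self):
        super().__init__()
        self.hyper_map = C.HyperMap()

    def construct(self, scale, grads):
        # Apply the registered overload element-wise over the grads tuple.
        return self.hyper_map(F.partial(_scale, scale), grads)

grads = (Tensor(np.ones(2), mstype.float32), Tensor(np.ones(3), mstype.float32))
print(ScaleGrads()(Tensor(2.0, mstype.float32), grads))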
@@ -77,6 +107,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
+        >>> from mindspore import boost
         >>> net = Net()
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
         >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
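The hunk cuts this docstring example off after the optimizer line; presumably the surrounding docstring goes on to build the boost training cell, roughly as follows (a hedged sketch, not quoted from the patch):

        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)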
@@ -280,7 +311,8 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         >>> from mindspore import Tensor, Parameter, nn
         >>> import mindspore.ops as ops
         >>> from mindspore.nn import WithLossCell
-        >>> from mindspore.common import dtype as mstype
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import boost
         >>>
         >>> class Net(nn.Cell):
         ...     def __init__(self, in_features, out_features):
@@ -301,8 +333,8 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         >>> net_with_loss = WithLossCell(net, loss)
         >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
         >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
-        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
-        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
+        >>> input = Tensor(np.ones([out_features, in_features]), mstype.float32)
+        >>> labels = Tensor(np.ones([out_features,]), mstype.float32)
         >>> output = train_network(input, labels)
         >>>
         >>> #2) when the type of scale_sense is Tensor:
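The second case is also cut off by the hunk. When `scale_sense` is a plain Tensor rather than an update cell, the loss scale stays fixed. An illustrative sketch of that variant, with assumed shape and values, not quoted from the patch:

        >>> scaling_sens = Tensor(np.full((1,), 2 ** 12), dtype=mstype.float32)
        >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
        >>> output = train_network(input, labels)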
@@ -342,15 +374,15 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         loss = self.network(*inputs)
         scaling_sens = self.scale_sense

-        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
+        status, scaling_sens = self._start_overflow_check(loss, scaling_sens)

         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
         grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)

         # get the overflow buffer
-        cond = self.get_overflow_status(status, grads)
-        overflow = self.process_loss_scale(cond)
+        cond = self._get_overflow_status(status, grads)
+        overflow = self._process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
             if self.use_grad_accumulation:
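For readers unfamiliar with loss scaling: the graph multiplies the loss by `scaling_sens` before differentiation so that small fp16 gradients do not underflow, and `_grad_scale` then divides the result back. In plain numbers (an illustrative sketch, unrelated to any specific network):

scale = 2.0 ** 12
tiny_grad_fp32 = 3e-8                 # representable in fp32, would vanish in fp16
scaled = tiny_grad_fp32 * scale       # large enough to survive an fp16 cast
recovered = scaled / scale            # _grad_scale undoes the scaling afterwards
assert abs(recovered - tiny_grad_fp32) < 1e-12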
@@ -361,3 +393,98 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         else:
             loss = F.depend(loss, self.optimizer(grads))
         return loss, cond, scaling_sens
+
+    def _set_sense_scale(self, sens):
+        """
+        If the user has set `sens` during training and wants to reassign the value, this function can be
+        called again to make the modification; `sens` must be a Tensor.
+
+        Inputs:
+            - **sens** (Tensor) - The new sense whose shape and type are the same as the original `scale_sense`.
+        """
+        if self.scale_sense and isinstance(sens, Tensor):
+            self.scale_sense.set_data(sens)
+        else:
+            raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
+
+    def _start_overflow_check(self, pre_cond, compute_input):
+        """
+        Start floating-point overflow detection. Create and clear the overflow detection state.
+
+        Specify the arguments 'pre_cond' and 'compute_input' to make sure the overflow status is cleared at
+        the right time. For example, to clear the state after the loss calculation and then detect overflow
+        during the gradient calculation, pre_cond should be the output of the loss function, and
+        compute_input should be the input of the gradient-computing function.
+
+        Inputs:
+            - **pre_cond** (Tensor) - A precondition for starting overflow detection. It determines the execution
+              order of overflow state clearing and the prior processing, ensuring that '_start_overflow_check'
+              clears the status only after the precondition has finished.
+            - **compute_input** (object) - The input of the subsequent process. Overflow detection should be
+              performed on a certain computation. Set `compute_input` as the input of that computation to ensure
+              the overflow status is cleared before the computation executes.
+
+        Outputs:
+            Tuple[object, object], the first value is False for the GPU backend, while it is an instance of
+            NPUAllocFloatStatus for other backends. The status is used to detect overflow during overflow
+            detection. The second value is the same as `compute_input`, but carries some information about
+            the execution order.
+        """
+        status = False
+        if not self.gpu_target:
+            # init overflow buffer
+            status = P.NPUAllocFloatStatus()()
+            status = F.depend(status, pre_cond)
+            # clear overflow buffer
+            clear_status = P.NPUClearFloatStatus()(status)
+            compute_input = F.depend(compute_input, clear_status)
+        return status, compute_input
+
+    def _get_overflow_status(self, status, compute_output):
+        """
+        Get floating-point overflow status.
+
+        Get overflow results after executing the target process for overflow detection.
+
+        Inputs:
+            - **status** (object) - A status instance used to detect the overflow.
+            - **compute_output** - Overflow detection should be performed on a certain computation. Set
+              `compute_output` as the output of that computation to ensure the overflow status is acquired
+              after the computation executes.
+
+        Outputs:
+            bool, whether the overflow occurs or not.
+        """
+        if not self.gpu_target:
+            status = F.depend(status, compute_output)
+            get_status = P.NPUGetFloatStatus()(status)
+            status = F.depend(status, get_status)
+            # sum overflow buffer elements; 0: no overflow, >0: overflow
+            flag_sum = self.reduce_sum(status, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
+            flag_sum = P.AddN()(flag_sum)
+            # convert flag_sum to scalar
+            flag_sum = P.Reshape()(flag_sum, (()))
+
+        if self.is_distributed:
+            # sum overflow flag over devices
+            flag_reduce = self.allreduce(flag_sum)
+            overflow = self.less_equal(self.base, flag_reduce)
+        else:
+            overflow = self.less_equal(self.base, flag_sum)
+        return overflow
+
+    def _process_loss_scale(self, overflow):
+        """
+        Calculate loss scale according to the overflow.
+
+        Inputs:
+            - **overflow** (bool) - Whether the overflow occurs or not.
+
+        Outputs:
+            bool, overflow value.
+        """
+        if self.loss_scaling_manager is not None:
+            return self.loss_scaling_manager(self.scale_sense, overflow)
+        return overflow
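`_process_loss_scale` hands the overflow flag to the configured loss-scale manager (for example the `nn.DynamicLossScaleUpdateCell` shown in the docstring above), which typically shrinks the scale on overflow and grows it again after a run of clean steps. A plain-Python sketch of that policy, with assumed parameter names and defaults rather than the MindSpore implementation:

def update_scale(scale, overflow, good_steps, scale_factor=2, scale_window=1000):
    # Dynamic loss-scale policy: shrink on overflow, grow after a stable window.
    if overflow:
        return max(scale / scale_factor, 1), 0    # shrink the scale, reset the counter
    good_steps += 1
    if good_steps >= scale_window:
        return scale * scale_factor, 0            # grow the scale after enough clean steps
    return scale, good_steps

scale, steps = 2.0 ** 12, 0
for overflow in (False, True, False):
    scale, steps = update_scale(scale, overflow, steps)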
@@ -23,7 +23,7 @@ namespace mindspore {
 namespace opt {
 namespace irpass {
 namespace {
-enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
+enum class RemoveNodeType { kOtherNode = 0, kOptimizerNode };
 const char kLessBatchNormalizationPassName[] = "less_bn";
 constexpr auto kValidResidualStructureIndex = 1;
 constexpr auto kBNParametersStartIndex = 2;