diff --git a/mindspore/boost/__init__.py b/mindspore/boost/__init__.py
index 40e438502a6..01411e7a3e2 100644
--- a/mindspore/boost/__init__.py
+++ b/mindspore/boost/__init__.py
@@ -19,13 +19,13 @@ accumulation and so on.
 Note:
     This feature is a beta feature, and we are still improving its functionality.
 """
-from .boost import *
-from .base import *
-from .boost_cell_wrapper import *
-from .less_batch_normalization import *
-from .grad_freeze import *
-from .grad_accumulation import *
-from .adasum import *
+from .boost import AutoBoost
+from .base import OptimizerProcess, ParameterProcess
+from .boost_cell_wrapper import BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell
+from .less_batch_normalization import LessBN
+from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell
+from .grad_accumulation import GradientAccumulation
+from .adasum import AdaSum


 __all__ = ['AutoBoost',
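Replacing the wildcard imports with explicit ones makes the package's import surface match its `__all__` list and keeps unrelated module-level names from leaking into `mindspore.boost`. A minimal sketch of what downstream code can now rely on (the names are exactly those exported in the hunk above; assumes a standard MindSpore install):

```python
# Explicit imports now mirror mindspore/boost/__init__.py, so linters and
# IDEs can resolve every name without star-import expansion.
from mindspore.boost import (
    AutoBoost,
    OptimizerProcess, ParameterProcess,
    BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell,
    LessBN,
    GradientFreeze, FreezeOpt, freeze_cell,
    GradientAccumulation,
    AdaSum,
)
```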
diff --git a/mindspore/boost/boost_cell_wrapper.py b/mindspore/boost/boost_cell_wrapper.py
index ba85cf2f271..1a3ad4bf798 100644
--- a/mindspore/boost/boost_cell_wrapper.py
+++ b/mindspore/boost/boost_cell_wrapper.py
@@ -51,6 +51,36 @@ def _save_weight_process(new_parameter, old_parameter):
     return P.Assign()(new_parameter, old_parameter)


+_grad_scale = C.MultitypeFuncGraph("grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))
+
+
+@_grad_scale.register("Tensor", "RowTensor")
+def tensor_grad_scale_row_tensor(scale, grad):
+    return RowTensor(grad.indices,
+                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
+                     grad.dense_shape)
+
+
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
+
+
+@_grad_overflow.register("RowTensor")
+def _tensor_grad_overflow_row_tensor(grad):
+    return grad_overflow(grad.values)
+
+
 class BoostTrainOneStepCell(TrainOneStepCell):
     r"""
     Boost Network training package class.
@@ -77,6 +107,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
+        >>> from mindspore import boost
         >>> net = Net()
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
         >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
@@ -280,7 +311,8 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         >>> from mindspore import Tensor, Parameter, nn
         >>> import mindspore.ops as ops
         >>> from mindspore.nn import WithLossCell
-        >>> from mindspore.common import dtype as mstype
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore import boost
         >>>
         >>> class Net(nn.Cell):
         ...     def __init__(self, in_features, out_features):
@@ -301,8 +333,8 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
         >>> net_with_loss = WithLossCell(net, loss)
         >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
         >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
-        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
-        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
+        >>> input = Tensor(np.ones([out_features, in_features]), mstype.float32)
+        >>> labels = Tensor(np.ones([out_features,]), mstype.float32)
         >>> output = train_network(input, labels)
         >>>
         >>> #2) when the type of scale_sense is Tensor:
@@ -342,15 +374,15 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):

         loss = self.network(*inputs)
         scaling_sens = self.scale_sense
-        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
+        status, scaling_sens = self._start_overflow_check(loss, scaling_sens)

         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
         grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)

         # get the overflow buffer
-        cond = self.get_overflow_status(status, grads)
-        overflow = self.process_loss_scale(cond)
+        cond = self._get_overflow_status(status, grads)
+        overflow = self._process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
             if self.use_grad_accumulation:
@@ -361,3 +393,98 @@
             else:
                 loss = F.depend(loss, self.optimizer(grads))
         return loss, cond, scaling_sens
+
+    def _set_sense_scale(self, sens):
+        """
+        If the user has set `sens` during training and wants to reassign its value, this method
+        can be called again to make the modification; `sens` must be a Tensor.
+
+        Inputs:
+            - **sens** (Tensor) - The new sense value, whose shape and type are the same as the
+              original `scale_sense`.
+        """
+        if self.scale_sense and isinstance(sens, Tensor):
+            self.scale_sense.set_data(sens)
+        else:
+            raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
+
+    def _start_overflow_check(self, pre_cond, compute_input):
+        """
+        Start floating-point overflow detection. Create and clear the overflow detection state.
+
+        Specify the arguments `pre_cond` and `compute_input` to make sure the overflow status is
+        cleared at the right time. For example, to clear the state after the loss calculation and
+        then detect overflow during the gradient calculation, `pre_cond` should be the output of
+        the loss function and `compute_input` should be the input of the gradient-computing
+        function.
+
+        Inputs:
+            - **pre_cond** (Tensor) - A precondition for starting overflow detection. It determines
+              the execution order of the overflow state clearing and the prior operations, making
+              sure that `_start_overflow_check` clears the status only after the precondition has
+              finished.
+            - **compute_input** (object) - The input of the subsequent computation. Overflow
+              detection should be performed on a certain computation; set `compute_input` as the
+              input of that computation to ensure the overflow status is cleared before the
+              computation executes.
+
+        Outputs:
+            Tuple[object, object], the first value is False for the GPU backend and an instance of
+            NPUAllocFloatStatus for other backends; it is the status used to detect overflow.
+            The second value is the same as `compute_input`, but carries the required
+            execution-order dependency.
+        """
+        status = False
+        if not self.gpu_target:
+            # init overflow buffer
+            status = P.NPUAllocFloatStatus()()
+            status = F.depend(status, pre_cond)
+            # clear overflow buffer
+            clear_status = P.NPUClearFloatStatus()(status)
+            compute_input = F.depend(compute_input, clear_status)
+        return status, compute_input
+
+    def _get_overflow_status(self, status, compute_output):
+        """
+        Get the floating-point overflow status.
+
+        Get the overflow result after executing the target process for overflow detection.
+
+        Inputs:
+            - **status** (object) - The status instance used to detect the overflow.
+            - **compute_output** - Overflow detection should be performed on a certain computation;
+              set `compute_output` as the output of that computation to ensure the overflow status
+              is acquired only after the computation has finished.
+
+        Outputs:
+            bool, whether the overflow occurs or not.
+        """
+        if not self.gpu_target:
+            status = F.depend(status, compute_output)
+            get_status = P.NPUGetFloatStatus()(status)
+            status = F.depend(status, get_status)
+            # sum the overflow buffer elements; 0: no overflow, >0: overflow
+            flag_sum = self.reduce_sum(status, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
+            flag_sum = P.AddN()(flag_sum)
+            # convert flag_sum to scalar
+            flag_sum = P.Reshape()(flag_sum, (()))
+
+        if self.is_distributed:
+            # sum overflow flag over devices
+            flag_reduce = self.allreduce(flag_sum)
+            overflow = self.less_equal(self.base, flag_reduce)
+        else:
+            overflow = self.less_equal(self.base, flag_sum)
+        return overflow
+
+    def _process_loss_scale(self, overflow):
+        """
+        Calculate the loss scale according to the overflow flag.
+
+        Inputs:
+            - **overflow** (bool) - Whether the overflow occurs or not.
+
+        Outputs:
+            bool, overflow value.
+        """
+        if self.loss_scaling_manager is not None:
+            return self.loss_scaling_manager(self.scale_sense, overflow)
+        return overflow
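For reviewers unfamiliar with the control flow that `_start_overflow_check`, `_get_overflow_status`, and `_process_loss_scale` implement, here is a framework-agnostic sketch of dynamic loss scaling in plain NumPy. The function name `train_step`, the toy SGD update, and the shrink-by-2 policy are illustrative assumptions, not MindSpore API:

```python
import numpy as np

def train_step(params, raw_grads, loss_scale):
    """One step of dynamic loss scaling (illustration only)."""
    # Unscale the gradients, as _grad_scale does via reciprocal(scale).
    grads = [g / loss_scale for g in raw_grads]

    # Overflow check: any non-finite value marks the step as overflowed,
    # which is what summing the NPU overflow buffer (or FloatStatus on
    # GPU) detects in _get_overflow_status.
    overflow = any(not np.isfinite(g).all() for g in grads)

    if overflow:
        # Skip the parameter update and shrink the scale, much as a
        # DynamicLossScaleUpdateCell with scale_factor=2 would.
        loss_scale = max(loss_scale / 2.0, 1.0)
    else:
        params = [p - 0.1 * g for p, g in zip(params, grads)]  # toy SGD update
    return params, loss_scale, overflow

# Toy usage: an inf gradient triggers the skip-and-shrink path.
params, scale, overflow = train_step([np.ones(2)], [np.array([np.inf, 1.0])], 2.0 ** 12)
assert overflow and scale == 2.0 ** 11
```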
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc
index f0adf80adbe..8fd894b408a 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc
@@ -23,7 +23,7 @@ namespace mindspore {
 namespace opt {
 namespace irpass {
 namespace {
-enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
+enum class RemoveNodeType { kOtherNode = 0, kOptimizerNode };
 const char kLessBatchNormalizationPassName[] = "less_bn";
 constexpr auto kValidResidualStructureIndex = 1;
 constexpr auto kBNParametersStartIndex = 2;