diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index 304996a4c48..793d8eb4095 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@ from ...common.parameter import Parameter
 from ...ops import functional as F
 from ...ops import composite as C
 from ...ops import operations as P
-from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual
 from ...common import dtype as mstype
 
 _grad_scale = C.MultitypeFuncGraph("grad_scale")
@@ -275,22 +274,12 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     def __init__(self, network, optimizer, scale_sense):
         super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.hyper_map = C.HyperMap()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = NPUAllocFloatStatus()
-            self.get_status = NPUGetFloatStatus()
-            self.clear_status = NPUClearFloatStatus()
-            self.reduce_sum = ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
-        self.less_equal = LessEqual()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.less_equal = P.LessEqual()
         self.allreduce = P.AllReduce()
-        self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
-
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.gpu_target = (context.get_context("device_target") == "GPU")
         self.loss_scaling_manager = None
         if isinstance(scale_sense, Cell):
             self.loss_scaling_manager = scale_sense
@@ -307,43 +296,19 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
-        init = False
-        if not self.gpu_target:
-            # init overflow buffer
-            init = self.alloc_status()
-            # clear overflow buffer after loss calculated
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            loss = F.depend(loss, clear_status)
-
         scaling_sens = self.scale_sense
+
+        status, scaling_sens = self.start_overflow(loss, scaling_sens)
+
         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
         grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
+
         # get the overflow buffer
-        if not self.gpu_target:
-            # get overflow status after grads calculated
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            # sum overflow buffer elements, 0:not overflow , >0:overflow
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            # convert flag_sum to scalar
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if self.loss_scaling_manager is not None:
-            overflow = self.loss_scaling_manager(self.scale_sense, cond)
+        cond = self.detect_overflow(status, grads)
+        overflow = self.process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
             loss = F.depend(loss, self.optimizer(grads))
@@ -356,3 +321,84 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
             self.scale_sense.set_data(sens)
         else:
             raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
+
+    def start_overflow(self, pre_cond, compute_input):
+        """
+        Start floating-point overflow detection. Create and clear the overflow detection state.
+
+        Specify the arguments 'pre_cond' and 'compute_input' to make sure the overflow status is cleared at the right
+        time. Taking this situation as an example, we need to execute state clearing after loss calculation and then
+        detect overflow in the process of gradient calculation. In this case, pre_cond should be the output of the
+        loss function, and compute_input should be the input of the gradients-computing function.
+
+        Args:
+            pre_cond(object): A precondition for starting overflow detection. It determines the executing order of
+                overflow state clearing and prior processes. It makes sure that the function 'start_overflow' clears
+                the status only after the precondition has finished.
+            compute_input(object): The input of the subsequent process. Overflow detection should be performed on a
+                certain computation. Set `compute_input` as the input of the computation, to ensure the overflow
+                status is cleared before executing the computation.
+
+        Returns:
+            Tuple[object, object], the first value is False for the GPU backend, while it is an instance of
+            NPUAllocFloatStatus for other backends. The status is used to detect overflow during overflow detection.
+            The second value is the same as the input of `compute_input`, but contains some information about the
+            execution order.
+        """
+        status = False
+        if not self.gpu_target:
+            # init overflow buffer
+            status = P.NPUAllocFloatStatus()()
+            status = F.depend(status, pre_cond)
+            # clear overflow buffer
+            clear_status = P.NPUClearFloatStatus()(status)
+            compute_input = F.depend(compute_input, clear_status)
+        return status, compute_input
+
+    def detect_overflow(self, status, compute_output):
+        """
+        Detect floating-point overflow status.
+
+        Get overflow results after executing the target process for overflow detection.
+
+        Args:
+            status(object): A status instance used to detect the overflow.
+            compute_output: Overflow detection should be performed on a certain computation. Set `compute_output` as
+                the output of the computation, to ensure the overflow status is acquired after executing the computation.
+
+        Returns:
+            bool, whether the overflow occurs or not.
+        """
+        if not self.gpu_target:
+            status = F.depend(status, compute_output)
+            get_status = P.NPUGetFloatStatus()(status)
+            status = F.depend(status, get_status)
+            # sum overflow buffer elements, 0: not overflow, >0: overflow
+            flag_sum = self.reduce_sum(status, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
+            flag_sum = P.AddN()(flag_sum)
+            # convert flag_sum to scalar
+            flag_sum = P.Reshape()(flag_sum, (()))
+
+        if self.is_distributed:
+            # sum overflow flag over devices
+            flag_reduce = self.allreduce(flag_sum)
+            overflow = self.less_equal(self.base, flag_reduce)
+        else:
+            overflow = self.less_equal(self.base, flag_sum)
+        return overflow
+
+    def process_loss_scale(self, overflow):
+        """
+        Calculate the loss scale according to the overflow.
+
+        Args:
+            overflow(bool): Whether the overflow occurs or not.
+
+        Returns:
+            bool, overflow value.
+        """
+        if self.loss_scaling_manager is not None:
+            return self.loss_scaling_manager(self.scale_sense, overflow)
+        return overflow
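The refactoring above turns the Ascend/GPU-specific overflow plumbing into three reusable hooks (start_overflow, detect_overflow, process_loss_scale) that subclasses such as the BERT cells below can call directly. The snippet that follows is a minimal, hedged sketch of driving the refactored cell end to end; the Dense backbone, optimizer settings, loss scale value and batch shapes are placeholders, and it assumes a GPU or Ascend target where the float-status operators are available.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype

# stand-in network and loss; any Cell returning a scalar loss works here
net = nn.Dense(16, 10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_loss = nn.WithLossCell(net, loss_fn)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# scale_sense may be a fixed Tensor, or an update Cell such as nn.DynamicLossScaleUpdateCell,
# in which case process_loss_scale adjusts the scale whenever an overflow is reported
scale_sense = Tensor(1024.0, mstype.float32)
train_cell = nn.TrainOneStepWithLossScaleCell(net_with_loss, opt, scale_sense=scale_sense)

data = Tensor(np.random.randn(32, 16).astype(np.float32))
label = Tensor(np.random.randint(0, 10, (32,)).astype(np.int32))
# one training step: the cell returns the loss, the overflow flag and the scale that was applied
loss, overflow, sens = train_cell(data, label)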
+ """ + if self.loss_scaling_manager is not None: + return self.loss_scaling_manager(self.scale_sense, overflow) + return overflow diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index 5113ea59275..314647f2c5f 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -330,7 +330,7 @@ def _tensor_grad_overflow(grad): return grad_overflow(grad) -class BertTrainOneStepWithLossScaleCell(nn.Cell): +class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): """ Encapsulation class of bert network training. @@ -344,39 +344,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): """ def __init__(self, network, optimizer, scale_update_cell=None): - super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = F.identity + super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) + self.cast = P.Cast() self.degree = 1 if self.reducer_flag: self.degree = get_group_size() self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.cast = P.Cast() - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.gpu_target = False - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() + self.loss_scale = None self.loss_scaling_manager = scale_update_cell if scale_update_cell: @@ -404,13 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): scaling_sens = self.loss_scale else: scaling_sens = sens - init = False - if not self.gpu_target: - # alloc status and clear should be right before gradoperation - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -424,21 +392,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - if not self.gpu_target: - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - 
@@ -449,7 +404,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)
 
-class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
+
+class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
     """
     Encapsulation class of bert network training.
 
@@ -464,40 +420,12 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = F.identity
+        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
+        self.cast = P.Cast()
         self.degree = 1
         if self.reducer_flag:
             self.degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.cast = P.Cast()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-            self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
         self.loss_scale = None
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
@@ -525,14 +453,8 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
-        init = False
-        if not self.gpu_target:
-            # alloc status and clear should be right before gradoperation
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
-
+
+        status, scaling_sens = self.start_overflow(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -546,21 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
+        cond = self.detect_overflow(status, grads)
         overflow = cond
         if self.loss_scaling_manager is not None:
             overflow = self.loss_scaling_manager(scaling_sens, cond)
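Both BERT cells now delegate the flag aggregation that previously lived in the removed branches to the inherited detect_overflow. The decision it makes is simple enough to show in plain Python (an illustration of the logic only, not MindSpore code): each device contributes a float that is 0 when its status buffer recorded no overflow and at least 1 otherwise; the flags are summed across devices (the AllReduce) and compared against base = 1.

def any_device_overflowed(per_device_flags, base=1.0):
    # stands in for P.AllReduce() followed by P.LessEqual()(base, flag_reduce)
    total = sum(per_device_flags)
    return base <= total

print(any_device_overflowed([0.0, 0.0, 0.0, 0.0]))   # False: no device overflowed, the optimizer step runs
print(any_device_overflowed([0.0, 1.0, 0.0, 0.0]))   # True: one device overflowed, so the step is skipped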