From 0c97835662a79f6d6bf5d68514d2d947d82b0266 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Wed, 30 Dec 2020 15:05:54 +0800 Subject: [PATCH] update control flow in adamweightdecay for bert --- akg | 2 +- .../gpu/cuda_impl/float_status_impl.cu | 2 - .../gpu/math/float_status_gpu_kernel.h | 2 + model_zoo/official/nlp/bert/run_pretrain.py | 17 +- model_zoo/official/nlp/bert/src/__init__.py | 4 +- model_zoo/official/nlp/bert/src/adam.py | 307 ++++++++++++++++++ .../nlp/bert/src/bert_for_pre_training.py | 114 +++++++ 7 files changed, 438 insertions(+), 10 deletions(-) create mode 100644 model_zoo/official/nlp/bert/src/adam.py diff --git a/akg b/akg index ae997e27b21..f4f118a2deb 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit ae997e27b217d6c8c7a6cbf6ef812186835d2bdf +Subproject commit f4f118a2debd2eacc3f2ab6dc31846f1e04d6e13 diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu index bc400eb7049..081754c87a5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu @@ -88,7 +88,6 @@ __global__ void IsFinite(const size_t size, const half* input, bool* out) { template __global__ void FloatStatus(const size_t size, const T* input, T* out) { - out[0] = 0; for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { if (isinf(input[pos]) != 0 || isnan(input[pos])) { out[0] = 1; @@ -98,7 +97,6 @@ __global__ void FloatStatus(const size_t size, const T* input, T* out) { } template <> __global__ void FloatStatus(const size_t size, const half* input, half* out) { - out[0] = 0; for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) { out[0] = 1; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h index 1d3cfcc8cef..b1e7e4fd0d6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h @@ -24,6 +24,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" namespace mindspore { namespace kernel { @@ -46,6 +47,7 @@ class FloatStatusGpuKernel : public GpuKernel { switch (kernel_name_) { case OP_STATUS: { T *output = GetDeviceAddress(outputs, 0); + FillDeviceArray(outputs[0]->size / sizeof(T), output, 0.0f, reinterpret_cast(stream_ptr)); CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); break; } diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index ccbb34d57a5..52c14a28193 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -32,7 +32,8 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay from mindspore import log as logger from mindspore.common import set_seed from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \ - BertTrainAccumulateStepsWithLossScaleCell + BertTrainAccumulateStepsWithLossScaleCell,
BertTrainOneStepWithLossScaleCellForAdam, \ + AdamWeightDecayForBert from src.dataset import create_bert_dataset from src.config import cfg, bert_net_cfg from src.utils import LossCallBack, BertLearningRate @@ -83,8 +84,10 @@ def _get_optimizer(args_opt, network): group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, {'params': other_params, 'weight_decay': 0.0}, {'order_params': params}] - - optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) + if args_opt.enable_lossscale == "true": + optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) + else: + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) else: raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]". format(cfg.optimizer)) @@ -206,8 +209,12 @@ def run_pretrain(): scale_window=cfg.scale_window) if args_opt.accumulation_steps <= 1: - net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer, - scale_update_cell=update_cell) + if cfg.optimizer == 'AdamWeightDecay': + net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer, + scale_update_cell=update_cell) + else: + net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer, + scale_update_cell=update_cell) else: accumulation_steps = args_opt.accumulation_steps net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer, diff --git a/model_zoo/official/nlp/bert/src/__init__.py b/model_zoo/official/nlp/bert/src/__init__.py index 1bff219f391..aa5003a2b2e 100644 --- a/model_zoo/official/nlp/bert/src/__init__.py +++ b/model_zoo/official/nlp/bert/src/__init__.py @@ -21,13 +21,13 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ SaturateCast, CreateAttentionMaskFromInputMask - +from .adam import AdamWeightDecayForBert __all__ = [ "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss", "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell", "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", "BertSelfAttention", "BertTransformer", "EmbeddingLookup", "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert", "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask" ] diff --git a/model_zoo/official/nlp/bert/src/adam.py b/model_zoo/official/nlp/bert/src/adam.py new file mode 100644 index 00000000000..c7a952e2bb4 --- /dev/null +++ b/model_zoo/official/nlp/bert/src/adam.py @@ -0,0 +1,307 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""AdamWeightDecayForBert, a customized AdamWeightDecay optimizer for BERT. Inputs: gradients and an overflow flag.""" +import numpy as np + +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.common.tensor import Tensor +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel +from mindspore.nn.optim.optimizer import Optimizer + +_adam_opt = C.MultitypeFuncGraph("adam_opt") +_scaler_one = Tensor(1, mstype.int32) +_scaler_ten = Tensor(10, mstype.float32) + + +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", + "Tensor", "Bool", "Bool") +def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter): + """ + Update parameters. + + Args: + beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). + beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). + eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. + lr (Tensor): Learning rate. + overflow (Tensor): Whether overflow occurs. + weight_decay (Number): Weight decay. Should be equal to or greater than 0. + param (Tensor): Parameters. + m (Tensor): m value of parameters. + v (Tensor): v value of parameters. + gradient (Tensor): Gradient of parameters. + decay_flag (bool): Applies weight decay or not. + optim_filter (bool): Applies parameter update or not. + + Returns: + Tensor, the updated parameter value (or the original gradient if `optim_filter` is False).
+ """ + if optim_filter: + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + op_select = P.Select() + + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + + cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_) + next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\ + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)) + + next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\ + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32))) + + update = next_m / (eps + op_sqrt(next_v)) + if decay_flag: + update = op_mul(weight_decay, param_fp32) + update + + update_with_lr = op_mul(lr, update) + zeros = F.fill(mstype.float32, op_shape(param_fp32), 0) + next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32))) + + next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) + next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) + next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) + + return op_cast(next_param, F.dtype(param)) + return gradient + + +@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") +def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, + beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable): + """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" + success = True + indices = gradient.indices + values = gradient.values + if ps_parameter and not cache_enable: + op_shape = P.Shape() + shapes = (op_shape(param), op_shape(m), op_shape(v), + op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), + op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, + eps, values, indices), shapes), param)) + return success + + if not target: + success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2, + eps, values, indices)) + else: + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + scatter_add = P.ScatterAdd(use_locking) + + assign_m = F.assign(m, op_mul(beta1, m)) + assign_v = F.assign(v, op_mul(beta2, v)) + + grad_indices = gradient.indices + grad_value = gradient.values + + next_m = scatter_add(m, + grad_indices, + op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) + + next_v = scatter_add(v, + grad_indices, + op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value))) + + if use_nesterov: + m_temp = next_m * _scaler_ten + assign_m_nesterov = F.assign(m, op_mul(beta1, next_m)) + div_value = scatter_add(m, + op_mul(grad_indices, _scaler_one), + op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) + param_update = div_value / (op_sqrt(next_v) + eps) + + m_recover = F.assign(m, m_temp / _scaler_ten) + + F.control_depend(m_temp, assign_m_nesterov) + F.control_depend(assign_m_nesterov, div_value) + F.control_depend(param_update, m_recover) + else: + param_update = next_m / (op_sqrt(next_v) + eps) + + lr_t = lr *
op_sqrt(1 - beta2_power) / (1 - beta1_power) + + next_param = param - lr_t * param_update + + F.control_depend(assign_m, next_m) + F.control_depend(assign_v, next_v) + + success = F.depend(success, F.assign(param, next_param)) + success = F.depend(success, F.assign(m, next_m)) + success = F.depend(success, F.assign(v, next_v)) + + return success + + +@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") +def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, + beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, + moment1, moment2, ps_parameter, cache_enable): + """Apply adam optimizer to the weight parameter using Tensor.""" + success = True + if ps_parameter and not cache_enable: + op_shape = P.Shape() + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), + (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) + else: + success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, + eps, gradient)) + return success + + +@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor") +def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2): + """Apply AdamOffload optimizer to the weight parameter using Tensor.""" + success = True + delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient) + success = F.depend(success, F.assign_add(param, delta_param)) + return success + + +def _check_param_value(beta1, beta2, eps, prim_name): + """Check the type of inputs.""" + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) + validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) + validator.check_positive_float(eps, "eps", prim_name) + +class AdamWeightDecayForBert(Optimizer): + """ + Implements the Adam algorithm with weight decay (AdamWeightDecay), taking an additional overflow flag so that parameter updates are skipped when the loss scale overflows. + + Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve performance when using parameter groups, a customized order of parameters is supported. + + Args: + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys that can be parsed. + + - params: Required. The value must be a list of `Parameter`. + + - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional.
If "order_params" is in the keys, the value must be the order of parameters and + the order will be followed in the optimizer. There are no other keys in the `dict`, and the parameters + listed in 'order_params' must be in one of the group parameters. + + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. + When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero + dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. + beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. + Should be in range (0.0, 1.0). + beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. + Should be in range (0.0, 1.0). + eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. + Should be greater than 0. + weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. + + Inputs: + - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. + - **overflow** (Tensor) - The overflow flag from dynamic loss scaling. + + Outputs: + tuple[bool], all elements are True. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.AdamWeightDecay(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + ... {'params': no_conv_params, 'lr': 0.01}, + ... {'order_params': net.trainable_params()}] + >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. + >>> # The final order of parameters followed by the optimizer is the value of 'order_params'.
+ >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) + """ + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) + self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) + self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) + self.eps = Tensor(np.array([eps]).astype(np.float32)) + self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') + self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') + self.hyper_map = C.HyperMap() + self.op_select = P.Select() + self.op_cast = P.Cast() + self.op_reshape = P.Reshape() + self.op_shape = P.Shape() + + def construct(self, gradients, overflow): + """AdamWeightDecayForBert""" + lr = self.get_lr() + cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\ + self.op_reshape(overflow, (())), mstype.bool_) + beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1) + beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2) + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), + lr, self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow), + self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), + self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + if self.use_parallel: + self.broadcast_params(optim_result) + return optim_result diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index a4fcbc0aa62..2ab077d4a40 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -440,6 +440,120 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): ret = (loss, cond, scaling_sens) return F.depend(ret, succ) +class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow + condition as input. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
+ """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation(get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + if context.get_context("device_target") == "GPU": + self.gpu_target = True + self.float_status = P.FloatStatus() + self.addn = P.AddN() + self.reshape = P.Reshape() + else: + self.gpu_target = False + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + init = False + if not self.gpu_target: + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if not self.gpu_target: + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + else: + flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) + flag_sum = self.addn(flag_sum) + flag_sum = self.reshape(flag_sum, (())) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if self.loss_scaling_manager is not None: + overflow = self.loss_scaling_manager(scaling_sens, cond) + succ = self.optimizer(grads, overflow) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) cast = P.Cast() update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
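Note: the behavioural core of this patch is that the overflow flag is fed into AdamWeightDecayForBert and the parameter update is masked out with Select, instead of branching around the optimizer call in the graph. Below is a minimal NumPy sketch of the intended per-parameter semantics; it is illustrative only (the function name, the scalar hyper-parameter values and the treatment of the moments on overflow are assumptions, not a line-by-line translation of _update_run_op).

import numpy as np

def adam_weight_decay_step(param, m, v, grad, lr=1e-3, beta1=0.9, beta2=0.999,
                           eps=1e-6, weight_decay=0.01, overflow=False):
    """One AdamWeightDecay step; the update is skipped when `overflow` is True."""
    if overflow:
        # Loss-scale overflow: leave the parameter (and, in this sketch, the moments) untouched.
        return param, m, v
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    update = m / (np.sqrt(v) + eps) + weight_decay * param  # decoupled weight decay
    return param - lr * update, m, v

# An overflowed step leaves the parameter unchanged; a normal step applies the update.
p, m, v = np.ones(3, np.float32), np.zeros(3, np.float32), np.zeros(3, np.float32)
g = np.full(3, 0.1, np.float32)
p_ok, _, _ = adam_weight_decay_step(p, m, v, g, overflow=False)
p_skip, _, _ = adam_weight_decay_step(p, m, v, g, overflow=True)
assert np.array_equal(p_skip, p)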