From 40fc11e9a49309ee60dc021063e31c491c4e8ea9 Mon Sep 17 00:00:00 2001 From: shibeiji Date: Tue, 25 Aug 2020 10:54:33 +0800 Subject: [PATCH] Mimic higher batch size by accumulating gradients N times before weight update --- mindspore/ops/functional.py | 1 + model_zoo/official/nlp/bert/README.md | 2 + model_zoo/official/nlp/bert/run_pretrain.py | 26 ++- .../hyper_parameter_config.ini | 3 +- .../scripts/run_standalone_pretrain_ascend.sh | 1 + model_zoo/official/nlp/bert/src/__init__.py | 6 +- .../nlp/bert/src/bert_for_pre_training.py | 161 ++++++++++++++++++ model_zoo/official/nlp/bert/src/config.py | 2 +- 8 files changed, 195 insertions(+), 7 deletions(-) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 8eb7a71b77b..13a26b1e057 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -63,6 +63,7 @@ check_bprop = P.CheckBprop() equal = P.Equal() not_equal = P.NotEqual() assign_sub = P.AssignSub() +assign_add = P.AssignAdd() assign = P.Assign() square = P.Square() sqrt = P.Sqrt() diff --git a/model_zoo/official/nlp/bert/README.md b/model_zoo/official/nlp/bert/README.md index 8325bb7c70d..57b68fe7308 100644 --- a/model_zoo/official/nlp/bert/README.md +++ b/model_zoo/official/nlp/bert/README.md @@ -123,6 +123,7 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_n [--enable_save_ckpt ENABLE_SAVE_CKPT] [--device_target DEVICE_TARGET] [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] + [--accumulation_steps N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH] [--load_checkpoint_path LOAD_CHECKPOINT_PATH] [--save_checkpoint_steps N] [--save_checkpoint_num N] @@ -139,6 +140,7 @@ options: --do_shuffle enable shuffle: "true" | "false", default is "true" --enable_data_sink enable data sink: "true" | "false", default is "true" --data_sink_steps set data sink steps: N, default is 1 + --accumulation_steps accumulate gradients N times before weight update: N, default is 1 --save_checkpoint_path path to save checkpoint files: PATH, default is "" --load_checkpoint_path path to load checkpoint files: PATH, default is "" --save_checkpoint_steps steps for saving checkpoint files: N, default is 1000 diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index 06cf905fcfc..73b1021003e 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -30,7 +30,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay from mindspore import log as logger -from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell +from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \ + BertTrainAccumulateStepsWithLossScaleCell from src.dataset import create_bert_dataset from src.config import cfg, bert_net_cfg from src.utils import LossCallBack, BertLearningRate @@ -51,6 +52,8 @@ def run_pretrain(): parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.") + parser.add_argument("--accumulation_steps", type=int, default="1", + help="Accumulating gradients N times before weight update, default is 1.") parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, " @@ -98,6 +101,16 @@ def run_pretrain(): logger.warning('Gpu only support fp32 temporarily, run with fp32.') bert_net_cfg.compute_type = mstype.float32 + if args_opt.accumulation_steps > 1: + logger.info("accumulation steps: {}".format(args_opt.accumulation_steps)) + logger.info("global batch size: {}".format(bert_net_cfg.batch_size * args_opt.accumulation_steps)) + if args_opt.enable_data_sink == "true": + args_opt.data_sink_steps *= args_opt.accumulation_steps + logger.info("data sink steps: {}".format(args_opt.data_sink_steps)) + if args_opt.enable_save_ckpt == "true": + args_opt.save_checkpoint_steps *= args_opt.accumulation_steps + logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps)) + ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) net_with_loss = BertNetworkWithLoss(bert_net_cfg, True) @@ -157,8 +170,15 @@ def run_pretrain(): update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, scale_factor=cfg.scale_factor, scale_window=cfg.scale_window) - net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer, - scale_update_cell=update_cell) + + if args_opt.accumulation_steps <= 1: + net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer, + scale_update_cell=update_cell) + else: + accumulation_steps = args_opt.accumulation_steps + net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer, + scale_update_cell=update_cell, + accumulation_steps=accumulation_steps) else: net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer) diff --git a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini index 2298f83509b..1af5bdbae40 100644 --- a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini +++ b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini @@ -6,6 +6,7 @@ enable_lossscale=true do_shuffle=true enable_data_sink=true data_sink_steps=100 +accumulation_steps=1 save_checkpoint_path=./checkpoint/ save_checkpoint_steps=10000 -save_checkpoint_num=1 \ No newline at end of file +save_checkpoint_num=1 diff --git a/model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_ascend.sh b/model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_ascend.sh index c6048f4f678..ae07a1bda9f 100644 --- a/model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_ascend.sh +++ b/model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_ascend.sh @@ -39,6 +39,7 @@ python ${PROJECT_DIR}/../run_pretrain.py \ --do_shuffle="true" \ --enable_data_sink="true" \ --data_sink_steps=1 \ + --accumulation_steps=1 \ --load_checkpoint_path="" \ --save_checkpoint_steps=10000 \ --save_checkpoint_num=1 \ diff --git a/model_zoo/official/nlp/bert/src/__init__.py b/model_zoo/official/nlp/bert/src/__init__.py index 4f4584a4b48..1bff219f391 100644 --- a/model_zoo/official/nlp/bert/src/__init__.py +++ b/model_zoo/official/nlp/bert/src/__init__.py @@ -15,7 +15,8 @@ """Bert Init.""" from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \ BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \ - BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell + BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \ + BertTrainAccumulateStepsWithLossScaleCell from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ @@ -23,7 +24,8 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ __all__ = [ "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss", - "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell", + "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", + "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell", "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", "BertSelfAttention", "BertTransformer", "EmbeddingLookup", "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index 14e00281d0c..8607c3ba872 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -438,3 +438,164 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): succ = self.optimizer(grads) ret = (loss, cond, scaling_sens) return F.depend(ret, succ) + + +cast = P.Cast() +update_accu_grads = C.MultitypeFuncGraph("update_accu_grads") + + +@update_accu_grads.register("Tensor", "Tensor") +def _update_accu_grads(accu_grad, grad): + succ = True + return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32))) + + +zeroslike = P.ZerosLike() +reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads") + + +@reset_accu_grads.register("Tensor") +def _reset_accu_grads(accu_grad): + succ = True + return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad))) + + +class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. To mimic higher batch size, gradients are + accumulated N times before weight update. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = + batch_size * accumulation_steps. Default: 1. + """ + def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1): + super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.accumulation_steps = accumulation_steps + self.one = Tensor(np.array([1]).astype(np.int32)) + self.zero = Tensor(np.array([0]).astype(np.int32)) + self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step") + self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') + self.accu_overflow = Parameter(initializer(0, [1], mstype.int32), name="accu_overflow") + self.loss = Parameter(initializer(0, [1], mstype.float32), name="accu_loss") + + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.overflow_reducer = F.identity + if self.is_distributed: + self.overflow_reducer = P.AllReduce() + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.logical_or = P.LogicalOr() + self.not_equal = P.NotEqual() + self.select = P.Select() + self.reshape = P.Reshape() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + # update accumulation parameters + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) + self.loss = self.select(is_accu_step, self.loss + loss, loss) + mean_loss = self.loss / self.local_step + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + + accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads) + mean_loss = F.depend(mean_loss, accu_succ) + + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + overflow = self.less_equal(self.base, flag_sum) + overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) + accu_overflow = self.select(overflow, self.one, self.zero) + self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) + + if is_accu_step: + succ = False + else: + # apply grad reducer on grads + grads = self.grad_reducer(self.accu_grads) + scaling = scaling_sens * self.degree * self.accumulation_steps + grads = self.hyper_map(F.partial(grad_scale, scaling), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + accu_overflow = self.overflow_reducer(accu_overflow) + F.control_depend(grads, accu_overflow) + overflow = self.less_equal(self.base, accu_overflow) + accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) + overflow = F.depend(overflow, accu_succ) + overflow = self.reshape(overflow, (())) + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, overflow) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + + ret = (mean_loss, overflow, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py index c8692d1216c..d0da37f60d6 100644 --- a/model_zoo/official/nlp/bert/src/config.py +++ b/model_zoo/official/nlp/bert/src/config.py @@ -50,7 +50,7 @@ cfg = edict({ ''' Including two kinds of network: \ -base: Goole BERT-base(the base version of BERT model). +base: Google BERT-base(the base version of BERT model). large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \ Functional Relative Posetional Encoding as an effective positional encoding scheme). '''