allreduce after each step in gradient accumulation mode for BERT

shibeiji 2020-09-24 16:03:04 +08:00
parent 03e655f14a
commit f0b08e8bff
5 changed files with 211 additions and 34 deletions


@@ -239,30 +239,32 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_n
                         [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
                         [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
                         [--accumulation_steps N]
+                        [--allreduce_post_accumulation ALLREDUCE_POST_ACCUMULATION]
                         [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
                         [--load_checkpoint_path LOAD_CHECKPOINT_PATH]
                         [--save_checkpoint_steps N] [--save_checkpoint_num N]
                         [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [train_steps N]
 options:
     --device_target                device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
     --distribute                   pre_training by several devices: "true"(training by more than 1 device) | "false", default is "false"
     --epoch_size                   epoch size: N, default is 1
     --device_num                   number of used devices: N, default is 1
     --device_id                    device id: N, default is 0
     --enable_save_ckpt             enable save checkpoint: "true" | "false", default is "true"
     --enable_lossscale             enable lossscale: "true" | "false", default is "true"
     --do_shuffle                   enable shuffle: "true" | "false", default is "true"
     --enable_data_sink             enable data sink: "true" | "false", default is "true"
     --data_sink_steps              set data sink steps: N, default is 1
     --accumulation_steps           accumulate gradients N times before weight update: N, default is 1
+    --allreduce_post_accumulation  allreduce after accumulation of N steps or after each step: "true" | "false", default is "true"
     --save_checkpoint_path         path to save checkpoint files: PATH, default is ""
     --load_checkpoint_path         path to load checkpoint files: PATH, default is ""
     --save_checkpoint_steps        steps for saving checkpoint files: N, default is 1000
     --save_checkpoint_num          number for saving checkpoint files: N, default is 1
     --train_steps                  Training Steps: N, default is -1
     --data_dir                     path to dataset directory: PATH, default is ""
     --schema_dir                   path to schema.json file, PATH, default is ""
 ```
 ### Fine-Tuning and Evaluation


@@ -32,7 +32,9 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
 from mindspore import log as logger
 from mindspore.common import set_seed
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
-    BertTrainAccumulateStepsWithLossScaleCell, BertTrainOneStepWithLossScaleCellForAdam, \
+    BertTrainAccumulationAllReduceEachWithLossScaleCell, \
+    BertTrainAccumulationAllReducePostWithLossScaleCell, \
+    BertTrainOneStepWithLossScaleCellForAdam, \
     AdamWeightDecayForBert
 from src.dataset import create_bert_dataset
 from src.config import cfg, bert_net_cfg
@@ -122,6 +124,8 @@ def run_pretrain():
     parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
     parser.add_argument("--accumulation_steps", type=int, default="1",
                         help="Accumulating gradients N times before weight update, default is 1.")
+    parser.add_argument("--allreduce_post_accumulation", type=str, default="true", choices=["true", "false"],
+                        help="Whether to allreduce after accumulation of N steps or after each step, default is true.")
     parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
     parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
@@ -207,8 +211,9 @@ def run_pretrain():
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-        if args_opt.accumulation_steps <= 1:
+        accumulation_steps = args_opt.accumulation_steps
+        enable_global_norm = cfg.enable_global_norm
+        if accumulation_steps <= 1:
             if cfg.optimizer == 'AdamWeightDecay':
                 net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
                                                                           scale_update_cell=update_cell)
@@ -216,11 +221,13 @@ def run_pretrain():
                 net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                    scale_update_cell=update_cell)
         else:
-            accumulation_steps = args_opt.accumulation_steps
-            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                                       scale_update_cell=update_cell,
-                                                                       accumulation_steps=accumulation_steps,
-                                                                       enable_global_norm=cfg.enable_global_norm)
+            allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
+            net_with_accumulation = (BertTrainAccumulationAllReducePostWithLossScaleCell if allreduce_post else
+                                     BertTrainAccumulationAllReduceEachWithLossScaleCell)
+            net_with_grads = net_with_accumulation(net_with_loss, optimizer=optimizer,
+                                                   scale_update_cell=update_cell,
+                                                   accumulation_steps=accumulation_steps,
+                                                   enable_global_norm=enable_global_norm)
     else:
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
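For reference, a small framework-free sketch (the helper name and assertions are invented here, not part of the commit) that mirrors the cell-selection logic added above. This branch is only reached when accumulation_steps > 1 with loss scaling enabled; it makes the decision table explicit.

```python
# Hypothetical helper (not part of the commit) mirroring the selection added to run_pretrain.py.
def choose_accumulation_cell(distribute: str, allreduce_post_accumulation: str) -> str:
    """Return the name of the gradient-accumulation cell run_pretrain.py would build."""
    allreduce_post = distribute == "false" or allreduce_post_accumulation == "true"
    return ("BertTrainAccumulationAllReducePostWithLossScaleCell" if allreduce_post
            else "BertTrainAccumulationAllReduceEachWithLossScaleCell")

# Single-device runs always take the "post" variant, whatever the flag says.
assert choose_accumulation_cell("false", "false").endswith("PostWithLossScaleCell")
# The "each" variant is only used for distributed runs that opt out of post-accumulation allreduce.
assert choose_accumulation_cell("true", "false").endswith("EachWithLossScaleCell")
assert choose_accumulation_cell("true", "true").endswith("PostWithLossScaleCell")
```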


@@ -7,6 +7,7 @@ do_shuffle=true
 enable_data_sink=true
 data_sink_steps=100
 accumulation_steps=1
+allreduce_post_accumulation=true
 save_checkpoint_path=./
 save_checkpoint_steps=10000
 save_checkpoint_num=1


@@ -16,7 +16,8 @@
 from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
     BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
     BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
-    BertTrainAccumulateStepsWithLossScaleCell
+    BertTrainAccumulationAllReduceEachWithLossScaleCell, \
+    BertTrainAccumulationAllReducePostWithLossScaleCell
 from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
     BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
     EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
@@ -25,7 +26,8 @@ from .adam import AdamWeightDecayForBert
 __all__ = [
     "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
     "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
-    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
+    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulationAllReduceEachWithLossScaleCell",
+    "BertTrainAccumulationAllReducePostWithLossScaleCell",
     "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
     "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
     "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",


@@ -556,11 +556,24 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         return F.depend(ret, succ)

 cast = P.Cast()
-update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
+add_grads = C.MultitypeFuncGraph("add_grads")
+
+@add_grads.register("Tensor", "Tensor")
+def _add_grads(accu_grad, grad):
+    return accu_grad + cast(grad, mstype.float32)
+
+update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")

 @update_accu_grads.register("Tensor", "Tensor")
 def _update_accu_grads(accu_grad, grad):
+    succ = True
+    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
+
+accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
+
+@accumulate_accu_grads.register("Tensor", "Tensor")
+def _accumulate_accu_grads(accu_grad, grad):
     succ = True
     return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
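These registered helpers are applied elementwise over the tuples of per-parameter accumulation buffers and gradients via C.HyperMap. Below is a framework-free sketch of their semantics; NumPy arrays stand in for MindSpore tensors and hyper_map is a plain-Python stand-in, so none of this is MindSpore API.

```python
# Illustrative stand-ins only: what the three MultitypeFuncGraph helpers above do
# when HyperMap maps them over the (accu_grads, grads) tuples.
import numpy as np

def add_grads(accu_grad, grad):
    # functional form used by the "each" cell: returns the sum, leaves accu_grad untouched
    return accu_grad + grad.astype(np.float32)

def accumulate_accu_grads(accu_grad, grad):
    # assign_add form used by the "post" cell: accu_grad += grad, in place
    accu_grad += grad.astype(np.float32)

def update_accu_grads(accu_grad, grad):
    # assign form used by the "each" cell: overwrite accu_grad with the already-summed value
    accu_grad[...] = grad.astype(np.float32)

def hyper_map(fn, *tuples):
    # HyperMap applies fn to corresponding elements of the input tuples
    return tuple(fn(*elems) for elems in zip(*tuples))

accu = (np.zeros(2, np.float32), np.zeros(3, np.float32))
grads = (np.ones(2, np.float16), np.ones(3, np.float16))
summed = hyper_map(add_grads, accu, grads)       # new tuple; accu itself is unchanged
hyper_map(update_accu_grads, accu, summed)       # accu now holds the summed gradients
hyper_map(accumulate_accu_grads, accu, grads)    # accu grows by grads again, in place
```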
@@ -575,13 +588,17 @@ def _reset_accu_grads(accu_grad):
     return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))

-class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
+class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
     """
     Encapsulation class of bert network training.

     Append an optimizer to the training network after that the construct
-    function can be called to create the backward graph. To mimic higher batch size, gradients are
-    accumulated N times before weight update.
+    function can be called to create the backward graph.
+
+    To mimic higher batch size, gradients are accumulated N times before weight update.
+    For distribution mode, allreduce will only be implemented in the weight updated step,
+    i.e. the sub-step after gradients accumulated N times.

     Args:
         network (Cell): The training network. Note that loss function should have been added.
@@ -591,7 +608,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                   batch_size * accumulation_steps. Default: 1.
     """
     def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
-        super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
+        super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
         self.weights = optimizer.parameters
@@ -680,7 +697,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                                  self.cast(scaling_sens,
                                                            mstype.float32))

-        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
+        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
         mean_loss = F.depend(mean_loss, accu_succ)

         self.get_status(init)
@@ -716,3 +733,151 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         ret = (mean_loss, overflow, scaling_sens)
         return F.depend(ret, succ)
+
+
+class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
+    """
+    Encapsulation class of bert network training.
+
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    To mimic higher batch size, gradients are accumulated N times before weight update.
+    For distribution mode, allreduce will be implemented after each sub-step and the trailing time
+    will be overridden by backend optimization pass.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
+        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
+                                  batch_size * accumulation_steps. Default: 1.
+    """
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
+        super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.set_grad()
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
+        self.one = Tensor(np.array([1]).astype(np.int32))
+        self.zero = Tensor(np.array([0]).astype(np.int32))
+        self.local_step = Parameter(initializer(0, [1], mstype.int32))
+        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
+        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
+        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
+
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.overflow_reducer = F.identity
+        if self.is_distributed:
+            self.overflow_reducer = P.AllReduce()
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.logical_or = P.LogicalOr()
+        self.not_equal = P.NotEqual()
+        self.select = P.Select()
+        self.reshape = P.Reshape()
+        self.hyper_map = C.HyperMap()
+        self.loss_scale = None
+        self.loss_scaling_manager = scale_update_cell
+        if scale_update_cell:
+            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
+
+    @C.add_flags(has_effect=True)
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id,
+                  next_sentence_labels,
+                  masked_lm_positions,
+                  masked_lm_ids,
+                  masked_lm_weights,
+                  sens=None):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id,
+                            next_sentence_labels,
+                            masked_lm_positions,
+                            masked_lm_ids,
+                            masked_lm_weights)
+        if sens is None:
+            scaling_sens = self.loss_scale
+        else:
+            scaling_sens = sens
+
+        # update accumulation parameters
+        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
+        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
+        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
+        mean_loss = self.accu_loss / self.local_step
+        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
+
+        # alloc status and clear should be right before gradoperation
+        init = self.alloc_status()
+        self.clear_before_grad(init)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 next_sentence_labels,
+                                                 masked_lm_positions,
+                                                 masked_lm_ids,
+                                                 masked_lm_weights,
+                                                 self.cast(scaling_sens,
+                                                           mstype.float32))
+
+        accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
+        scaling = scaling_sens * self.degree * self.accumulation_steps
+        grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
+        grads = self.grad_reducer(grads)
+
+        self.get_status(init)
+        flag_sum = self.reduce_sum(init, (0,))
+        flag_reduce = self.overflow_reducer(flag_sum)
+        overflow = self.less_equal(self.base, flag_reduce)
+        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
+        accu_overflow = self.select(overflow, self.one, self.zero)
+        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
+        overflow = self.reshape(overflow, (()))
+
+        if is_accu_step:
+            succ = False
+            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
+            succ = F.depend(succ, accu_succ)
+        else:
+            if sens is None:
+                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
+            if overflow:
+                succ = False
+            else:
+                if self.enable_global_norm:
+                    grads = C.clip_by_global_norm(grads, 1.0, None)
+                else:
+                    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+                succ = self.optimizer(grads)
+            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
+            succ = F.depend(succ, accu_succ)
+
+        ret = (mean_loss, overflow, scaling_sens)
+        return F.depend(ret, succ)
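A back-of-the-envelope check (not part of the commit; it assumes every device and every sub-step sees the same raw gradient g) of the scaling used in the cell above: dividing the accumulated, loss-scaled gradients by scaling_sens * degree * accumulation_steps before the per-sub-step allreduce leaves the optimizer with an averaged gradient at the weight-update sub-step.

```python
# Illustrative arithmetic only; the numbers below are made up for the example.
loss_scale = 1024.0        # scaling_sens: the loss is multiplied by this before backprop
degree = 8                 # devices taking part in the allreduce (get_group_size())
accumulation_steps = 4     # sub-steps accumulated before each weight update
g = 0.5                    # raw (unscaled) per-device gradient

# at the update sub-step, each device holds the sum of accumulation_steps loss-scaled gradients
accumulated_per_device = accumulation_steps * g * loss_scale

# grad_scale divides by scaling_sens * degree * accumulation_steps ...
scaling = loss_scale * degree * accumulation_steps
scaled_per_device = accumulated_per_device / scaling

# ... and the DistributedGradReducer allreduce then sums over `degree` devices,
# recovering the plain gradient that the optimizer consumes
allreduced = degree * scaled_per_device
assert abs(allreduced - g) < 1e-9
```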