allreduce after each step in gradient accumulation mode for bert

shibeiji 2020-09-24 16:03:04 +08:00
parent 03e655f14a
commit f0b08e8bff
5 changed files with 211 additions and 34 deletions

View File

@ -239,30 +239,32 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_n
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
[--accumulation_steps N]
[--allreduce_post_accumulation ALLREDUCE_POST_ACCUMULATION]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [train_steps N]
options:
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
--distribute pre_training by several devices: "true"(training by more than 1 device) | "false", default is "false"
--epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1
--device_id device id: N, default is 0
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--accumulation_steps accumulate gradients N times before weight update: N, default is 1
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--train_steps Training Steps: N, default is -1
--data_dir path to dataset directory: PATH, default is ""
--schema_dir path to schema.json file: PATH, default is ""
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
--distribute pre_training by several devices: "true"(training by more than 1 device) | "false", default is "false"
--epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1
--device_id device id: N, default is 0
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--accumulation_steps accumulate gradients N times before weight update: N, default is 1
--allreduce_post_accumulation allreduce after accumulation of N steps or after each step: "true" | "false", default is "true"
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--train_steps Training Steps: N, default is -1
--data_dir path to dataset directory: PATH, default is ""
--schema_dir path to schema.json file: PATH, default is ""
```
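The new `--allreduce_post_accumulation` option only has an effect when `--accumulation_steps` is greater than 1 in a distributed run: with "true", gradients are all-reduced once per weight update; with "false", they are all-reduced after every accumulation sub-step. Below is a minimal sketch of the batch-size arithmetic behind the accumulation option; the per-device batch size and device count are assumed example values, not defaults set by this commit:

```
# Back-of-the-envelope sketch; batch_size and device_num are assumed example values.
batch_size = 32            # per-device batch size (assumed)
device_num = 8             # --device_num (assumed)
accumulation_steps = 4     # --accumulation_steps

# The class docstrings define: global batch size = batch_size * accumulation_steps (per device).
per_device_effective_batch = batch_size * accumulation_steps    # 128
# Under data parallelism the cluster-wide effective batch is device_num times larger.
cluster_wide_batch = per_device_effective_batch * device_num    # 1024
print(per_device_effective_batch, cluster_wide_batch)
```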
### Fine-Tuning and Evaluation

View File

@ -32,7 +32,9 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore import log as logger
from mindspore.common import set_seed
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
BertTrainAccumulateStepsWithLossScaleCell, BertTrainOneStepWithLossScaleCellForAdam, \
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
BertTrainAccumulationAllReducePostWithLossScaleCell, \
BertTrainOneStepWithLossScaleCellForAdam, \
AdamWeightDecayForBert
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
@ -122,6 +124,8 @@ def run_pretrain():
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
parser.add_argument("--accumulation_steps", type=int, default="1",
help="Accumulating gradients N times before weight update, default is 1.")
parser.add_argument("--allreduce_post_accumulation", type=str, default="true", choices=["true", "false"],
help="Whether to allreduce after accumulation of N steps or after each step, default is true.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
@ -207,8 +211,9 @@ def run_pretrain():
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window)
if args_opt.accumulation_steps <= 1:
accumulation_steps = args_opt.accumulation_steps
enable_global_norm = cfg.enable_global_norm
if accumulation_steps <= 1:
if cfg.optimizer == 'AdamWeightDecay':
net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell)
@ -216,11 +221,13 @@ def run_pretrain():
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell)
else:
accumulation_steps = args_opt.accumulation_steps
net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell,
accumulation_steps=accumulation_steps,
enable_global_norm=cfg.enable_global_norm)
allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
net_with_accumulation = (BertTrainAccumulationAllReducePostWithLossScaleCell if allreduce_post else
BertTrainAccumulationAllReduceEachWithLossScaleCell)
net_with_grads = net_with_accumulation(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell,
accumulation_steps=accumulation_steps,
enable_global_norm=enable_global_norm)
else:
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
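The wrapper selection above can be read as a small decision table. The helper below is a hypothetical, framework-free restatement of that logic (the function name and simplified parameters are mine; it returns class names instead of constructing the MindSpore cells):

```
# Hypothetical helper mirroring the selection logic in run_pretrain.py above.
def select_wrapper(optimizer_name, accumulation_steps, distribute,
                   allreduce_post_accumulation, enable_lossscale=True):
    if not enable_lossscale:
        return "BertTrainOneStepCell"
    if accumulation_steps <= 1:
        if optimizer_name == "AdamWeightDecay":
            return "BertTrainOneStepWithLossScaleCellForAdam"
        return "BertTrainOneStepWithLossScaleCell"
    # Stand-alone runs have nothing to all-reduce, so they fall back to the "post" cell.
    allreduce_post = distribute == "false" or allreduce_post_accumulation == "true"
    if allreduce_post:
        return "BertTrainAccumulationAllReducePostWithLossScaleCell"
    return "BertTrainAccumulationAllReduceEachWithLossScaleCell"

print(select_wrapper("Lamb", 4, "true", "false"))   # ...AllReduceEachWithLossScaleCell
print(select_wrapper("Lamb", 1, "true", "true"))    # BertTrainOneStepWithLossScaleCell
```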

View File

@ -7,6 +7,7 @@ do_shuffle=true
enable_data_sink=true
data_sink_steps=100
accumulation_steps=1
allreduce_post_accumulation=true
save_checkpoint_path=./
save_checkpoint_steps=10000
save_checkpoint_num=1

View File

@ -16,7 +16,8 @@
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
BertTrainAccumulateStepsWithLossScaleCell
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
BertTrainAccumulationAllReducePostWithLossScaleCell
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
@ -25,7 +26,8 @@ from .adam import AdamWeightDecayForBert
__all__ = [
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
"BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
"BertTrainOneStepWithLossScaleCell", "BertTrainAccumulationAllReduceEachWithLossScaleCell",
"BertTrainAccumulationAllReducePostWithLossScaleCell",
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",

View File

@ -556,11 +556,24 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
return F.depend(ret, succ)
cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
add_grads = C.MultitypeFuncGraph("add_grads")
@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
return accu_grad + cast(grad, mstype.float32)
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
succ = True
return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
succ = True
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
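The three `MultitypeFuncGraph` registrations above are applied element-wise over the tuple of accumulation buffers and the freshly computed gradients through `C.HyperMap`. Below is a framework-free analogue of that pattern, using plain Python floats in place of Tensors and return values in place of `assign`/`assign_add` side effects:

```
# Framework-free analogue of the HyperMap pattern; floats stand in for Tensors.
def hyper_map(fn, xs, ys):
    return tuple(fn(x, y) for x, y in zip(xs, ys))

def add_grads(accu, grad):              # accu_grad + grad, no write-back
    return accu + grad

def update_accu_grads(accu, grad):      # analogue of F.assign: overwrite the buffer
    return grad

def accumulate_accu_grads(accu, grad):  # analogue of F.assign_add: add into the buffer
    return accu + grad

accu = (0.0, 0.0)                       # the cloned accu_grads buffers
g1, g2 = (0.1, -0.2), (0.3, 0.4)        # gradients from two sub-steps

# "each" cell: build the running sum first, then write it back into the buffers.
summed = hyper_map(add_grads, accu, g1)
accu = hyper_map(update_accu_grads, accu, summed)

# "post" cell: accumulate directly into the buffers.
accu = hyper_map(accumulate_accu_grads, accu, g2)
print(accu)                             # roughly (0.4, 0.2)
```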
@ -575,13 +588,17 @@ def _reset_accu_grads(accu_grad):
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network; afterwards the construct
function can be called to create the backward graph. To mimic higher batch size, gradients are
accumulated N times before weight update.
function can be called to create the backward graph.
To mimic higher batch size, gradients are accumulated N times before weight update.
For distribution mode, allreduce is performed only at the weight-update step,
i.e. the sub-step after gradients have been accumulated N times.
Args:
network (Cell): The training network. Note that loss function should have been added.
@ -591,7 +608,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
batch_size * accumulation_steps. Default: 1.
"""
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
@ -680,7 +697,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
self.cast(scaling_sens,
mstype.float32))
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
mean_loss = F.depend(mean_loss, accu_succ)
self.get_status(init)
@ -716,3 +733,151 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
ret = (mean_loss, overflow, scaling_sens)
return F.depend(ret, succ)
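The renamed class keeps the original behaviour: gradients are accumulated locally and all-reduced once, at the weight-update sub-step. The new cell added below instead all-reduces after every sub-step, so the communication can overlap with the next sub-step's computation. A rough sketch of the communication volume per weight update; the parameter count and the 4-byte fp32 gradient width are assumptions, not figures from this commit:

```
# Rough communication sketch; num_params and bytes_per_grad are assumed values.
num_params = 110_000_000       # roughly BERT-base sized (assumed)
bytes_per_grad = 4             # gradients reduced in float32 (assumed)
accumulation_steps = 4

bytes_per_allreduce = num_params * bytes_per_grad
post_bytes_per_update = 1 * bytes_per_allreduce                   # one allreduce per update
each_bytes_per_update = accumulation_steps * bytes_per_allreduce  # one allreduce per sub-step

# "each" moves more data in total, but every transfer can be hidden behind the
# computation of the following sub-step, which is the point of the new cell.
print(post_bytes_per_update / 2**20, each_bytes_per_update / 2**20)   # MiB
```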
class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network; afterwards the construct
function can be called to create the backward graph.
To mimic higher batch size, gradients are accumulated N times before weight update.
For distribution mode, allreduce is performed after each sub-step, and the trailing
communication time is overlapped (hidden) by the backend optimization pass.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
enable_global_norm (bool): Whether to clip the accumulated gradients by global norm. Default: False.
"""
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.enable_global_norm = enable_global_norm
self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32))
self.local_step = Parameter(initializer(0, [1], mstype.int32))
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.overflow_reducer = F.identity
if self.is_distributed:
self.overflow_reducer = P.AllReduce()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.logical_or = P.LogicalOr()
self.not_equal = P.NotEqual()
self.select = P.Select()
self.reshape = P.Reshape()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
@C.add_flags(has_effect=True)
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# update accumulation parameters
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
mean_loss = self.accu_loss / self.local_step
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
# alloc status and clear should be right before the grad operation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
grads = self.grad_reducer(grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
flag_reduce = self.overflow_reducer(flag_sum)
overflow = self.less_equal(self.base, flag_reduce)
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
accu_overflow = self.select(overflow, self.one, self.zero)
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
overflow = self.reshape(overflow, (()))
if is_accu_step:
succ = False
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
succ = F.depend(succ, accu_succ)
else:
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
if overflow:
succ = False
else:
if self.enable_global_norm:
grads = C.clip_by_global_norm(grads, 1.0, None)
else:
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
succ = self.optimizer(grads)
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
succ = F.depend(succ, accu_succ)
ret = (mean_loss, overflow, scaling_sens)
return F.depend(ret, succ)
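The bookkeeping in `construct` is driven by the `local_step` counter and two `select` calls: the counter wraps back to 1 after a weight update, and the running loss is either extended or reset accordingly. Before the update, the accumulated gradients are rescaled by `scaling_sens * degree * accumulation_steps` (assuming `grad_scale` in this file divides by its first argument), which undoes the loss scaling and averages over devices and sub-steps. Below is a framework-free trace of that counter logic, with plain Python values in place of the Parameter/Tensor state and the all-reduce, overflow and optimizer calls omitted:

```
# Framework-free trace of the accumulation bookkeeping in construct() above.
accumulation_steps = 4
degree = 8                 # number of devices (assumed example value)
scaling_sens = 1024.0      # loss scale (assumed example value)

local_step, accu_loss = 0, 0.0
for step, loss in enumerate([2.0, 1.8, 1.9, 1.7, 1.6], start=1):
    is_accu_step = local_step != accumulation_steps          # self.not_equal(...)
    local_step = local_step + 1 if is_accu_step else 1       # self.select(...)
    accu_loss = accu_loss + loss if is_accu_step else loss
    mean_loss = accu_loss / local_step
    is_accu_step = local_step != accumulation_steps          # re-tested, as in construct()
    update_now = not is_accu_step                            # optimizer runs on this sub-step
    grad_divisor = scaling_sens * degree * accumulation_steps
    print(step, local_step, round(mean_loss, 3), update_now, grad_divisor)
# Sub-steps 1-3 only refresh the buffers; sub-step 4 triggers the weight update and
# sub-step 5 starts a new accumulation window with the counter back at 1.
```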