forked from mindspore-Ecosystem/mindspore

All-reduce after each step in gradient accumulation mode for BERT

This commit is contained in: parent 03e655f14a, commit f0b08e8bff
@@ -239,30 +239,32 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_n
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
[--accumulation_steps N]
[--allreduce_post_accumulation ALLREDUCE_POST_ACCUMULATION]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [--train_steps N]

options:
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
--distribute pre_training by several devices: "true"(training by more than 1 device) | "false", default is "false"
--epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1
--device_id device id: N, default is 0
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--accumulation_steps accumulate gradients N times before weight update: N, default is 1
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--train_steps Training Steps: N, default is -1
--data_dir path to dataset directory: PATH, default is ""
--schema_dir path to schema.json file, PATH, default is ""
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
--distribute pre_training by several devices: "true"(training by more than 1 device) | "false", default is "false"
--epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1
--device_id device id: N, default is 0
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--accumulation_steps accumulate gradients N times before weight update: N, default is 1
--allreduce_post_accumulation allreduce after accumulation of N steps or after each step: "true" | "false", default is "true"
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--train_steps Training Steps: N, default is -1
--data_dir path to dataset directory: PATH, default is ""
--schema_dir path to schema.json file, PATH, default is ""
```
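
The `--accumulation_steps` option controls how many mini-batches are accumulated before one weight update, which is what "mimicking a higher batch size" means in practice. A minimal arithmetic sketch of the relationship; the concrete numbers and the `device_num` factor for distributed runs are illustrative assumptions, not defaults of this repository:

```python
# Hypothetical values, chosen only to illustrate the relationship.
per_device_batch_size = 32
accumulation_steps = 4      # --accumulation_steps
device_num = 8              # --device_num

# Gradients are accumulated N times before each weight update, so one update
# effectively sees batch_size * accumulation_steps samples per device, times
# device_num devices when training is distributed.
effective_global_batch = per_device_batch_size * accumulation_steps * device_num
print(effective_global_batch)  # 1024
```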

### Fine-Tuning and Evaluation

@@ -32,7 +32,9 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore import log as logger
from mindspore.common import set_seed
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
    BertTrainAccumulateStepsWithLossScaleCell, BertTrainOneStepWithLossScaleCellForAdam, \
    BertTrainAccumulationAllReduceEachWithLossScaleCell, \
    BertTrainAccumulationAllReducePostWithLossScaleCell, \
    BertTrainOneStepWithLossScaleCellForAdam, \
    AdamWeightDecayForBert
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg

@@ -122,6 +124,8 @@ def run_pretrain():
    parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--accumulation_steps", type=int, default="1",
                        help="Accumulating gradients N times before weight update, default is 1.")
    parser.add_argument("--allreduce_post_accumulation", type=str, default="true", choices=["true", "false"],
                        help="Whether to allreduce after accumulation of N steps or after each step, default is true.")
    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "

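A standalone sketch of how the two new options parse (this is not the project's parser, only the argparse behaviour it relies on). Note that argparse applies the `type` callable to string defaults, so `default="1"` with `type=int` still yields an integer:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--accumulation_steps", type=int, default="1",
                    help="Accumulating gradients N times before weight update, default is 1.")
parser.add_argument("--allreduce_post_accumulation", type=str, default="true",
                    choices=["true", "false"],
                    help="Whether to allreduce after accumulation of N steps or after each step.")

args = parser.parse_args(["--accumulation_steps", "4",
                          "--allreduce_post_accumulation", "false"])
assert args.accumulation_steps == 4
assert args.allreduce_post_accumulation == "false"

# String defaults are parsed as well, so the no-argument case also gives an int.
assert parser.parse_args([]).accumulation_steps == 1
```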
@@ -207,8 +211,9 @@ def run_pretrain():
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)

        if args_opt.accumulation_steps <= 1:
        accumulation_steps = args_opt.accumulation_steps
        enable_global_norm = cfg.enable_global_norm
        if accumulation_steps <= 1:
            if cfg.optimizer == 'AdamWeightDecay':
                net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
                                                                          scale_update_cell=update_cell)

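A hedged, framework-free sketch of the dispatch this hunk sets up: with `accumulation_steps <= 1` training falls back to the plain one-step cells, otherwise an accumulation cell is chosen in the next hunk. Class names are returned as strings so the sketch stays self-contained:

```python
def select_training_cell(accumulation_steps: int, optimizer_name: str) -> str:
    """Sketch of the branch structure in run_pretrain.py (not the real factory)."""
    if accumulation_steps <= 1:
        # No accumulation: one-step cells, with a special case for AdamWeightDecay.
        if optimizer_name == "AdamWeightDecay":
            return "BertTrainOneStepWithLossScaleCellForAdam"
        return "BertTrainOneStepWithLossScaleCell"
    # Accumulation mode: the post/each choice is made separately (see the sketch
    # after the next hunk).
    return "accumulation cell"

assert select_training_cell(1, "Lamb") == "BertTrainOneStepWithLossScaleCell"
assert select_training_cell(4, "Lamb") == "accumulation cell"
```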
@@ -216,11 +221,13 @@ def run_pretrain():
                net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                   scale_update_cell=update_cell)
        else:
            accumulation_steps = args_opt.accumulation_steps
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                       scale_update_cell=update_cell,
                                                                       accumulation_steps=accumulation_steps,
                                                                       enable_global_norm=cfg.enable_global_norm)
            allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
            net_with_accumulation = (BertTrainAccumulationAllReducePostWithLossScaleCell if allreduce_post else
                                     BertTrainAccumulationAllReduceEachWithLossScaleCell)
            net_with_grads = net_with_accumulation(net_with_loss, optimizer=optimizer,
                                                   scale_update_cell=update_cell,
                                                   accumulation_steps=accumulation_steps,
                                                   enable_global_norm=enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

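The choice between the two accumulation cells reduces to a single boolean, mirroring the rule in the hunk above. A standalone sketch with class names as strings:

```python
def pick_accumulation_cell(distribute: str, allreduce_post_accumulation: str) -> str:
    """Single-device runs and an explicit "true" both select the post-accumulation cell."""
    allreduce_post = distribute == "false" or allreduce_post_accumulation == "true"
    return ("BertTrainAccumulationAllReducePostWithLossScaleCell" if allreduce_post
            else "BertTrainAccumulationAllReduceEachWithLossScaleCell")

# Distributed run that all-reduces after every sub-step (the new behaviour):
assert pick_accumulation_cell("true", "false") == \
    "BertTrainAccumulationAllReduceEachWithLossScaleCell"
# Single-device run: there is nothing to all-reduce per step, so the post cell is used.
assert pick_accumulation_cell("false", "true") == \
    "BertTrainAccumulationAllReducePostWithLossScaleCell"
```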
@@ -7,6 +7,7 @@ do_shuffle=true
enable_data_sink=true
data_sink_steps=100
accumulation_steps=1
allreduce_post_accumulation=true
save_checkpoint_path=./
save_checkpoint_steps=10000
save_checkpoint_num=1

@@ -16,7 +16,8 @@
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
    BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
    BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
    BertTrainAccumulateStepsWithLossScaleCell
    BertTrainAccumulationAllReduceEachWithLossScaleCell, \
    BertTrainAccumulationAllReducePostWithLossScaleCell
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
    BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
    EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \

@@ -25,7 +26,8 @@ from .adam import AdamWeightDecayForBert
__all__ = [
    "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
    "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulationAllReduceEachWithLossScaleCell",
    "BertTrainAccumulationAllReducePostWithLossScaleCell",
    "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
    "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
    "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",

@@ -556,11 +556,24 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
        return F.depend(ret, succ)

cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
add_grads = C.MultitypeFuncGraph("add_grads")


@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
    return accu_grad + cast(grad, mstype.float32)

update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")

@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))

accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")

@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))

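The three MultitypeFuncGraph helpers above are applied per parameter through hyper_map. A framework-free NumPy sketch of their per-tensor semantics (the real versions operate on MindSpore Tensors inside the compiled graph; this is only an illustration):

```python
import numpy as np

def add_grads(accu_grad, grad):
    # out-of-place: returns accumulated + new gradient, cast to float32
    return accu_grad + grad.astype(np.float32)

def update_accu_grads(accu_grad, grad):
    # in-place overwrite, mirroring F.assign
    accu_grad[...] = grad.astype(np.float32)

def accumulate_accu_grads(accu_grad, grad):
    # in-place add, mirroring F.assign_add
    accu_grad[...] += grad.astype(np.float32)

accu = np.zeros(3, dtype=np.float32)
accumulate_accu_grads(accu, np.ones(3, dtype=np.float16))        # accu == [1, 1, 1]
update_accu_grads(accu, np.full(3, 2.0, dtype=np.float16))       # accu == [2, 2, 2]
assert (add_grads(accu, np.ones(3, dtype=np.float16)) == 3.0).all()
```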
@@ -575,13 +588,17 @@ def _reset_accu_grads(accu_grad):
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph. To mimic higher batch size, gradients are
    accumulated N times before weight update.
    function can be called to create the backward graph.

    To mimic higher batch size, gradients are accumulated N times before weight update.

    For distribution mode, allreduce will only be implemented in the weight update step,
    i.e. the sub-step after gradients have been accumulated N times.

    Args:
        network (Cell): The training network. Note that loss function should have been added.

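A hedged sketch of the schedule this docstring describes: gradients stay local for N - 1 sub-steps and are only all-reduced (simulated here by a callback) on the weight-update sub-step. The helper names and the constant per-step gradient are illustrative, not part of the repository:

```python
def post_accumulation_schedule(num_steps, accumulation_steps, allreduce, apply_update):
    """Call allreduce/apply_update only on every accumulation_steps-th sub-step."""
    accu_grad = 0.0
    for step in range(1, num_steps + 1):
        accu_grad += 1.0                      # stand-in for this sub-step's local gradient
        if step % accumulation_steps == 0:    # weight-update sub-step
            apply_update(allreduce(accu_grad))
            accu_grad = 0.0                   # reset accumulators, as reset_accu_grads does

calls = []
post_accumulation_schedule(8, 4, allreduce=lambda g: g, apply_update=calls.append)
assert calls == [4.0, 4.0]   # two updates, each seeing 4 accumulated sub-steps
```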
@@ -591,7 +608,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
            batch_size * accumulation_steps. Default: 1.
    """
    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
        super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
        super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters

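This hunk renames BertTrainAccumulateStepsWithLossScaleCell to BertTrainAccumulationAllReducePostWithLossScaleCell. The commit does not keep the old name around; if external scripts still import it, a backward-compatibility alias would be one option. The snippet below is purely hypothetical and uses a stub class only so it runs on its own:

```python
class BertTrainAccumulationAllReducePostWithLossScaleCell:  # stub for illustration only
    pass

# Hypothetical backward-compatibility alias; the commit itself does not add one.
BertTrainAccumulateStepsWithLossScaleCell = BertTrainAccumulationAllReducePostWithLossScaleCell

assert BertTrainAccumulateStepsWithLossScaleCell is BertTrainAccumulationAllReducePostWithLossScaleCell
```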
@@ -680,7 +697,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        self.get_status(init)

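Under the new helper names, update_accu_grads overwrites the accumulator (F.assign) while accumulate_accu_grads adds to it (F.assign_add); the post-accumulation cell needs the adding variant so every sub-step contributes to the eventual update, which is what this hunk switches to. A plain-float sketch of the distinction (not the MindSpore ops themselves):

```python
def run_sub_steps(combine):
    accu = 0.0
    for grad in (1.0, 2.0, 3.0):   # three local sub-step gradients
        accu = combine(accu, grad)
    return accu

overwrite = lambda accu, grad: grad          # update_accu_grads: F.assign semantics
accumulate = lambda accu, grad: accu + grad  # accumulate_accu_grads: F.assign_add semantics

assert run_sub_steps(overwrite) == 3.0   # only the last sub-step would survive
assert run_sub_steps(accumulate) == 6.0  # all sub-steps contribute, as the post cell requires
```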
@@ -716,3 +733,151 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)


class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph.

    To mimic higher batch size, gradients are accumulated N times before weight update.

    For distribution mode, allreduce will be implemented after each sub-step and the trailing time
    will be overridden by the backend optimization pass.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
            batch_size * accumulation_steps. Default: 1.
    """
    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
        super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.enable_global_norm = enable_global_norm
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32))
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))

        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
        scaling = scaling_sens * self.degree * self.accumulation_steps
        grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
        grads = self.grad_reducer(grads)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        flag_reduce = self.overflow_reducer(flag_sum)
        overflow = self.less_equal(self.base, flag_reduce)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        overflow = self.reshape(overflow, (()))

        if is_accu_step:
            succ = False
            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
            succ = F.depend(succ, accu_succ)
        else:
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                if self.enable_global_norm:
                    grads = C.clip_by_global_norm(grads, 1.0, None)
                else:
                    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

                succ = self.optimizer(grads)

            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            succ = F.depend(succ, accu_succ)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)

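To contrast the two accumulation cells, here is a hedged, framework-free sketch of the "all-reduce each sub-step" schedule implemented by the class above: the running gradient total is all-reduced on every sub-step, but the optimizer only runs when local_step wraps around to accumulation_steps, mirroring the local_step/accu_grads bookkeeping in construct. All helper names, the scaling, and the two-device allreduce are illustrative assumptions:

```python
def allreduce_each_schedule(local_grads, accumulation_steps, allreduce, apply_update):
    """local_grads: one device's per-sub-step gradients; returns the number of updates."""
    accu_grad, local_step, updates = 0.0, 0, 0
    for grad in local_grads:
        # local_step wraps 1..accumulation_steps, like the P.Select logic in construct
        local_step = local_step + 1 if local_step != accumulation_steps else 1
        accu_grad += grad                      # add_grads: running total of local gradients
        reduced = allreduce(accu_grad)         # communication happens on *every* sub-step
        if local_step == accumulation_steps:   # weight-update sub-step
            apply_update(reduced / accumulation_steps)
            accu_grad = 0.0                    # reset_accu_grads
            updates += 1
    return updates

# Two devices are simulated by an allreduce that simply doubles the local total.
assert allreduce_each_schedule([1.0] * 8, 4, allreduce=lambda g: 2 * g,
                               apply_update=lambda g: None) == 2
```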