forked from mindspore-Ecosystem/mindspore
Mimic higher batch size by accumulating gradients N times before weight update
This commit is contained in:
parent
5b722a1037
commit
40fc11e9a4
|
@ -63,6 +63,7 @@ check_bprop = P.CheckBprop()
|
||||||
equal = P.Equal()
|
equal = P.Equal()
|
||||||
not_equal = P.NotEqual()
|
not_equal = P.NotEqual()
|
||||||
assign_sub = P.AssignSub()
|
assign_sub = P.AssignSub()
|
||||||
|
assign_add = P.AssignAdd()
|
||||||
assign = P.Assign()
|
assign = P.Assign()
|
||||||
square = P.Square()
|
square = P.Square()
|
||||||
sqrt = P.Sqrt()
|
sqrt = P.Sqrt()
|
||||||
|
|
|
@ -123,6 +123,7 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_n
|
||||||
[--enable_save_ckpt ENABLE_SAVE_CKPT] [--device_target DEVICE_TARGET]
|
[--enable_save_ckpt ENABLE_SAVE_CKPT] [--device_target DEVICE_TARGET]
|
||||||
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
|
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
|
||||||
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
|
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
|
||||||
|
[--accumulation_steps N]
|
||||||
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
|
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
|
||||||
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
|
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
|
||||||
[--save_checkpoint_steps N] [--save_checkpoint_num N]
|
[--save_checkpoint_steps N] [--save_checkpoint_num N]
|
||||||
|
@ -139,6 +140,7 @@ options:
|
||||||
--do_shuffle enable shuffle: "true" | "false", default is "true"
|
--do_shuffle enable shuffle: "true" | "false", default is "true"
|
||||||
--enable_data_sink enable data sink: "true" | "false", default is "true"
|
--enable_data_sink enable data sink: "true" | "false", default is "true"
|
||||||
--data_sink_steps set data sink steps: N, default is 1
|
--data_sink_steps set data sink steps: N, default is 1
|
||||||
|
--accumulation_steps accumulate gradients N times before weight update: N, default is 1
|
||||||
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
|
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
|
||||||
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
|
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
|
||||||
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
|
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
|
||||||
|
|
|
@ -30,7 +30,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||||
|
BertTrainAccumulateStepsWithLossScaleCell
|
||||||
from src.dataset import create_bert_dataset
|
from src.dataset import create_bert_dataset
|
||||||
from src.config import cfg, bert_net_cfg
|
from src.config import cfg, bert_net_cfg
|
||||||
from src.utils import LossCallBack, BertLearningRate
|
from src.utils import LossCallBack, BertLearningRate
|
||||||
|
@ -51,6 +52,8 @@ def run_pretrain():
|
||||||
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
|
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
|
||||||
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
||||||
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
|
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default="1",
|
||||||
|
help="Accumulating gradients N times before weight update, default is 1.")
|
||||||
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
||||||
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||||
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
|
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
|
||||||
|
@ -98,6 +101,16 @@ def run_pretrain():
|
||||||
logger.warning('Gpu only support fp32 temporarily, run with fp32.')
|
logger.warning('Gpu only support fp32 temporarily, run with fp32.')
|
||||||
bert_net_cfg.compute_type = mstype.float32
|
bert_net_cfg.compute_type = mstype.float32
|
||||||
|
|
||||||
|
if args_opt.accumulation_steps > 1:
|
||||||
|
logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
|
||||||
|
logger.info("global batch size: {}".format(bert_net_cfg.batch_size * args_opt.accumulation_steps))
|
||||||
|
if args_opt.enable_data_sink == "true":
|
||||||
|
args_opt.data_sink_steps *= args_opt.accumulation_steps
|
||||||
|
logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
|
||||||
|
if args_opt.enable_save_ckpt == "true":
|
||||||
|
args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
|
||||||
|
logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))
|
||||||
|
|
||||||
ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
|
ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
|
||||||
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
|
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
|
||||||
|
|
||||||
|
@ -157,8 +170,15 @@ def run_pretrain():
|
||||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||||
scale_factor=cfg.scale_factor,
|
scale_factor=cfg.scale_factor,
|
||||||
scale_window=cfg.scale_window)
|
scale_window=cfg.scale_window)
|
||||||
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
|
||||||
scale_update_cell=update_cell)
|
if args_opt.accumulation_steps <= 1:
|
||||||
|
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
||||||
|
scale_update_cell=update_cell)
|
||||||
|
else:
|
||||||
|
accumulation_steps = args_opt.accumulation_steps
|
||||||
|
net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
||||||
|
scale_update_cell=update_cell,
|
||||||
|
accumulation_steps=accumulation_steps)
|
||||||
else:
|
else:
|
||||||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
|
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ enable_lossscale=true
|
||||||
do_shuffle=true
|
do_shuffle=true
|
||||||
enable_data_sink=true
|
enable_data_sink=true
|
||||||
data_sink_steps=100
|
data_sink_steps=100
|
||||||
|
accumulation_steps=1
|
||||||
save_checkpoint_path=./checkpoint/
|
save_checkpoint_path=./checkpoint/
|
||||||
save_checkpoint_steps=10000
|
save_checkpoint_steps=10000
|
||||||
save_checkpoint_num=1
|
save_checkpoint_num=1
|
||||||
|
|
|
@ -39,6 +39,7 @@ python ${PROJECT_DIR}/../run_pretrain.py \
|
||||||
--do_shuffle="true" \
|
--do_shuffle="true" \
|
||||||
--enable_data_sink="true" \
|
--enable_data_sink="true" \
|
||||||
--data_sink_steps=1 \
|
--data_sink_steps=1 \
|
||||||
|
--accumulation_steps=1 \
|
||||||
--load_checkpoint_path="" \
|
--load_checkpoint_path="" \
|
||||||
--save_checkpoint_steps=10000 \
|
--save_checkpoint_steps=10000 \
|
||||||
--save_checkpoint_num=1 \
|
--save_checkpoint_num=1 \
|
||||||
|
|
|
@ -15,7 +15,8 @@
|
||||||
"""Bert Init."""
|
"""Bert Init."""
|
||||||
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
|
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
|
||||||
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
|
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
|
||||||
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||||
|
BertTrainAccumulateStepsWithLossScaleCell
|
||||||
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
|
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
|
||||||
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
|
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
|
||||||
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
|
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
|
||||||
|
@ -23,7 +24,8 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
|
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
|
||||||
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell",
|
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
|
||||||
|
"BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
|
||||||
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
|
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
|
||||||
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
|
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
|
||||||
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
|
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
|
||||||
|
|
|
@ -438,3 +438,164 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
succ = self.optimizer(grads)
|
succ = self.optimizer(grads)
|
||||||
ret = (loss, cond, scaling_sens)
|
ret = (loss, cond, scaling_sens)
|
||||||
return F.depend(ret, succ)
|
return F.depend(ret, succ)
|
||||||
|
|
||||||
|
|
||||||
|
cast = P.Cast()
|
||||||
|
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
|
||||||
|
|
||||||
|
|
||||||
|
@update_accu_grads.register("Tensor", "Tensor")
|
||||||
|
def _update_accu_grads(accu_grad, grad):
|
||||||
|
succ = True
|
||||||
|
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
|
||||||
|
|
||||||
|
|
||||||
|
zeroslike = P.ZerosLike()
|
||||||
|
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
|
||||||
|
|
||||||
|
|
||||||
|
@reset_accu_grads.register("Tensor")
|
||||||
|
def _reset_accu_grads(accu_grad):
|
||||||
|
succ = True
|
||||||
|
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
|
||||||
|
|
||||||
|
|
||||||
|
class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
|
||||||
|
"""
|
||||||
|
Encapsulation class of bert network training.
|
||||||
|
|
||||||
|
Append an optimizer to the training network after that the construct
|
||||||
|
function can be called to create the backward graph. To mimic higher batch size, gradients are
|
||||||
|
accumulated N times before weight update.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The training network. Note that loss function should have been added.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the weights.
|
||||||
|
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||||
|
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
|
||||||
|
batch_size * accumulation_steps. Default: 1.
|
||||||
|
"""
|
||||||
|
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
|
||||||
|
super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.weights = optimizer.parameters
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.accumulation_steps = accumulation_steps
|
||||||
|
self.one = Tensor(np.array([1]).astype(np.int32))
|
||||||
|
self.zero = Tensor(np.array([0]).astype(np.int32))
|
||||||
|
self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
|
||||||
|
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
|
||||||
|
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32), name="accu_overflow")
|
||||||
|
self.loss = Parameter(initializer(0, [1], mstype.float32), name="accu_loss")
|
||||||
|
|
||||||
|
self.grad = C.GradOperation('grad',
|
||||||
|
get_by_list=True,
|
||||||
|
sens_param=True)
|
||||||
|
self.reducer_flag = False
|
||||||
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||||
|
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||||
|
self.reducer_flag = True
|
||||||
|
self.grad_reducer = F.identity
|
||||||
|
self.degree = 1
|
||||||
|
if self.reducer_flag:
|
||||||
|
self.degree = get_group_size()
|
||||||
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||||
|
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||||
|
self.overflow_reducer = F.identity
|
||||||
|
if self.is_distributed:
|
||||||
|
self.overflow_reducer = P.AllReduce()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.alloc_status = P.NPUAllocFloatStatus()
|
||||||
|
self.get_status = P.NPUGetFloatStatus()
|
||||||
|
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||||
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||||
|
self.base = Tensor(1, mstype.float32)
|
||||||
|
self.less_equal = P.LessEqual()
|
||||||
|
self.logical_or = P.LogicalOr()
|
||||||
|
self.not_equal = P.NotEqual()
|
||||||
|
self.select = P.Select()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
self.loss_scale = None
|
||||||
|
self.loss_scaling_manager = scale_update_cell
|
||||||
|
if scale_update_cell:
|
||||||
|
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
|
||||||
|
name="loss_scale")
|
||||||
|
|
||||||
|
@C.add_flags(has_effect=True)
|
||||||
|
def construct(self,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_type_id,
|
||||||
|
next_sentence_labels,
|
||||||
|
masked_lm_positions,
|
||||||
|
masked_lm_ids,
|
||||||
|
masked_lm_weights,
|
||||||
|
sens=None):
|
||||||
|
"""Defines the computation performed."""
|
||||||
|
weights = self.weights
|
||||||
|
loss = self.network(input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_type_id,
|
||||||
|
next_sentence_labels,
|
||||||
|
masked_lm_positions,
|
||||||
|
masked_lm_ids,
|
||||||
|
masked_lm_weights)
|
||||||
|
if sens is None:
|
||||||
|
scaling_sens = self.loss_scale
|
||||||
|
else:
|
||||||
|
scaling_sens = sens
|
||||||
|
|
||||||
|
# update accumulation parameters
|
||||||
|
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||||
|
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
|
||||||
|
self.loss = self.select(is_accu_step, self.loss + loss, loss)
|
||||||
|
mean_loss = self.loss / self.local_step
|
||||||
|
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||||
|
|
||||||
|
# alloc status and clear should be right before gradoperation
|
||||||
|
init = self.alloc_status()
|
||||||
|
self.clear_before_grad(init)
|
||||||
|
grads = self.grad(self.network, weights)(input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_type_id,
|
||||||
|
next_sentence_labels,
|
||||||
|
masked_lm_positions,
|
||||||
|
masked_lm_ids,
|
||||||
|
masked_lm_weights,
|
||||||
|
self.cast(scaling_sens,
|
||||||
|
mstype.float32))
|
||||||
|
|
||||||
|
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
|
||||||
|
mean_loss = F.depend(mean_loss, accu_succ)
|
||||||
|
|
||||||
|
self.get_status(init)
|
||||||
|
flag_sum = self.reduce_sum(init, (0,))
|
||||||
|
overflow = self.less_equal(self.base, flag_sum)
|
||||||
|
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
|
||||||
|
accu_overflow = self.select(overflow, self.one, self.zero)
|
||||||
|
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
|
||||||
|
|
||||||
|
if is_accu_step:
|
||||||
|
succ = False
|
||||||
|
else:
|
||||||
|
# apply grad reducer on grads
|
||||||
|
grads = self.grad_reducer(self.accu_grads)
|
||||||
|
scaling = scaling_sens * self.degree * self.accumulation_steps
|
||||||
|
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
|
||||||
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||||
|
accu_overflow = self.overflow_reducer(accu_overflow)
|
||||||
|
F.control_depend(grads, accu_overflow)
|
||||||
|
overflow = self.less_equal(self.base, accu_overflow)
|
||||||
|
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
|
||||||
|
overflow = F.depend(overflow, accu_succ)
|
||||||
|
overflow = self.reshape(overflow, (()))
|
||||||
|
if sens is None:
|
||||||
|
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
|
||||||
|
if overflow:
|
||||||
|
succ = False
|
||||||
|
else:
|
||||||
|
succ = self.optimizer(grads)
|
||||||
|
|
||||||
|
ret = (mean_loss, overflow, scaling_sens)
|
||||||
|
return F.depend(ret, succ)
|
||||||
|
|
|
@ -50,7 +50,7 @@ cfg = edict({
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Including two kinds of network: \
|
Including two kinds of network: \
|
||||||
base: Goole BERT-base(the base version of BERT model).
|
base: Google BERT-base(the base version of BERT model).
|
||||||
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \
|
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \
|
||||||
Functional Relative Posetional Encoding as an effective positional encoding scheme).
|
Functional Relative Posetional Encoding as an effective positional encoding scheme).
|
||||||
'''
|
'''
|
||||||
|
|
Loading…
Reference in New Issue