Mimic higher batch size by accumulating gradients N times before weight update

This commit is contained in:
shibeiji 2020-08-25 10:54:33 +08:00
parent 5b722a1037
commit 40fc11e9a4
8 changed files with 195 additions and 7 deletions

View File

@ -63,6 +63,7 @@ check_bprop = P.CheckBprop()
equal = P.Equal() equal = P.Equal()
not_equal = P.NotEqual() not_equal = P.NotEqual()
assign_sub = P.AssignSub() assign_sub = P.AssignSub()
assign_add = P.AssignAdd()
assign = P.Assign() assign = P.Assign()
square = P.Square() square = P.Square()
sqrt = P.Sqrt() sqrt = P.Sqrt()

View File

@ -123,6 +123,7 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_n
[--enable_save_ckpt ENABLE_SAVE_CKPT] [--device_target DEVICE_TARGET] [--enable_save_ckpt ENABLE_SAVE_CKPT] [--device_target DEVICE_TARGET]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
[--accumulation_steps N]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH] [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH] [--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N] [--save_checkpoint_steps N] [--save_checkpoint_num N]
@ -139,6 +140,7 @@ options:
--do_shuffle enable shuffle: "true" | "false", default is "true" --do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true" --enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1 --data_sink_steps set data sink steps: N, default is 1
--accumulation_steps accumulate gradients N times before weight update: N, default is 1
--save_checkpoint_path path to save checkpoint files: PATH, default is "" --save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is "" --load_checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000 --save_checkpoint_steps steps for saving checkpoint files: N, default is 1000

View File

@ -30,7 +30,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore import log as logger from mindspore import log as logger
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
BertTrainAccumulateStepsWithLossScaleCell
from src.dataset import create_bert_dataset from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg from src.config import cfg, bert_net_cfg
from src.utils import LossCallBack, BertLearningRate from src.utils import LossCallBack, BertLearningRate
@ -51,6 +52,8 @@ def run_pretrain():
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.") parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
parser.add_argument("--accumulation_steps", type=int, default="1",
help="Accumulating gradients N times before weight update, default is 1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, " parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
@ -98,6 +101,16 @@ def run_pretrain():
logger.warning('Gpu only support fp32 temporarily, run with fp32.') logger.warning('Gpu only support fp32 temporarily, run with fp32.')
bert_net_cfg.compute_type = mstype.float32 bert_net_cfg.compute_type = mstype.float32
if args_opt.accumulation_steps > 1:
logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
logger.info("global batch size: {}".format(bert_net_cfg.batch_size * args_opt.accumulation_steps))
if args_opt.enable_data_sink == "true":
args_opt.data_sink_steps *= args_opt.accumulation_steps
logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
if args_opt.enable_save_ckpt == "true":
args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))
ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True) net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
@ -157,8 +170,15 @@ def run_pretrain():
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
scale_factor=cfg.scale_factor, scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window) scale_window=cfg.scale_window)
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell) if args_opt.accumulation_steps <= 1:
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell)
else:
accumulation_steps = args_opt.accumulation_steps
net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell,
accumulation_steps=accumulation_steps)
else: else:
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer) net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

View File

@ -6,6 +6,7 @@ enable_lossscale=true
do_shuffle=true do_shuffle=true
enable_data_sink=true enable_data_sink=true
data_sink_steps=100 data_sink_steps=100
accumulation_steps=1
save_checkpoint_path=./checkpoint/ save_checkpoint_path=./checkpoint/
save_checkpoint_steps=10000 save_checkpoint_steps=10000
save_checkpoint_num=1 save_checkpoint_num=1

View File

@ -39,6 +39,7 @@ python ${PROJECT_DIR}/../run_pretrain.py \
--do_shuffle="true" \ --do_shuffle="true" \
--enable_data_sink="true" \ --enable_data_sink="true" \
--data_sink_steps=1 \ --data_sink_steps=1 \
--accumulation_steps=1 \
--load_checkpoint_path="" \ --load_checkpoint_path="" \
--save_checkpoint_steps=10000 \ --save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \ --save_checkpoint_num=1 \

View File

@ -15,7 +15,8 @@
"""Bert Init.""" """Bert Init."""
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \ from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \ BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
BertTrainAccumulateStepsWithLossScaleCell
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
@ -23,7 +24,8 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
__all__ = [ __all__ = [
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss", "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell", "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
"BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
"BertSelfAttention", "BertTransformer", "EmbeddingLookup", "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",

View File

@ -438,3 +438,164 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
succ = self.optimizer(grads) succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens) ret = (loss, cond, scaling_sens)
return F.depend(ret, succ) return F.depend(ret, succ)
cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
succ = True
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
succ = True
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph. To mimic higher batch size, gradients are
accumulated N times before weight update.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
"""
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = optimizer.parameters
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32))
self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32), name="accu_overflow")
self.loss = Parameter(initializer(0, [1], mstype.float32), name="accu_loss")
self.grad = C.GradOperation('grad',
get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.overflow_reducer = F.identity
if self.is_distributed:
self.overflow_reducer = P.AllReduce()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.logical_or = P.LogicalOr()
self.not_equal = P.NotEqual()
self.select = P.Select()
self.reshape = P.Reshape()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
@C.add_flags(has_effect=True)
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# update accumulation parameters
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
self.loss = self.select(is_accu_step, self.loss + loss, loss)
mean_loss = self.loss / self.local_step
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
mean_loss = F.depend(mean_loss, accu_succ)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
overflow = self.less_equal(self.base, flag_sum)
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
accu_overflow = self.select(overflow, self.one, self.zero)
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
if is_accu_step:
succ = False
else:
# apply grad reducer on grads
grads = self.grad_reducer(self.accu_grads)
scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
accu_overflow = self.overflow_reducer(accu_overflow)
F.control_depend(grads, accu_overflow)
overflow = self.less_equal(self.base, accu_overflow)
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
overflow = F.depend(overflow, accu_succ)
overflow = self.reshape(overflow, (()))
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (mean_loss, overflow, scaling_sens)
return F.depend(ret, succ)

View File

@ -50,7 +50,7 @@ cfg = edict({
''' '''
Including two kinds of network: \ Including two kinds of network: \
base: Goole BERT-base(the base version of BERT model). base: Google BERT-base(the base version of BERT model).
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \ large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \
Functional Relative Posetional Encoding as an effective positional encoding scheme). Functional Relative Posetional Encoding as an effective positional encoding scheme).
''' '''