forked from mindspore-Ecosystem/mindspore
!3269 fix bert scripts to adapt the new concept of repeat_count in minddata
Merge pull request !3269 from chenhaozhe/fix-bert
commit 1823cc06c3
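
What changed: minddata no longer bakes epochs into the dataset via repeat()/set_dataset_size(), so create_bert_dataset() drops its epoch_size/enable_data_sink/data_sink_steps parameters and run_pretrain() now derives the repeat count from the dataset size and the sink step count itself. A minimal, runnable sketch of that arithmetic (names mirror the diff below; the numbers are hypothetical, not from the commit):

    def compute_repeat_count(epoch_size, dataset_size, data_sink_steps, train_steps=0):
        # Under dataset sinking, each "epoch" reported by the trainer runs
        # data_sink_steps steps, so the repeat count is total steps / sink window.
        new_repeat_count = epoch_size * dataset_size // data_sink_steps
        if train_steps > 0:
            # An explicit step budget caps the repeat count.
            new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)
        return new_repeat_count

    # 4 epochs over a 320-step dataset, sinking 100 steps at a time -> 12 repeats
    # (the trailing 20 steps are dropped by the integer division).
    assert compute_repeat_count(4, 320, 100) == 12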
@@ -64,7 +64,6 @@ def run_pretrain():
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
-    context.set_context(variable_memory_max_size="30GB")
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
@@ -99,47 +98,49 @@ def run_pretrain():
             logger.warning('Gpu only support fp32 temporarily, run with fp32.')
             bert_net_cfg.compute_type = mstype.float32

-    ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle,
-                             args_opt.enable_data_sink, args_opt.data_sink_steps,
-                             args_opt.data_dir, args_opt.schema_dir)
-    new_repeat_count = args_opt.epoch_size
+    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
+
+    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
     if args_opt.train_steps > 0:
-        new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
-    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
+        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+    else:
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
+
     if cfg.optimizer == 'Lamb':
         lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                        end_learning_rate=cfg.Lamb.end_learning_rate,
                                        warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.Lamb.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params}]
+                        {'params': other_params},
+                        {'order_params': params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
+        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                              momentum=cfg.Momentum.momentum)
     elif cfg.optimizer == 'AdamWeightDecay':
         lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                        end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                        warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.AdamWeightDecay.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0}]
+                        {'params': other_params, 'weight_decay': 0.0},
+                        {'order_params': params}]

         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                          format(cfg.optimizer))
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
@@ -148,19 +149,22 @@ def run_pretrain():
     if args_opt.load_checkpoint_path:
         param_dict = load_checkpoint(args_opt.load_checkpoint_path)
-        load_param_into_net(netwithloss, param_dict)
+        load_param_into_net(net_with_loss, param_dict)

     if args_opt.enable_lossscale == "true":
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
-                                                         scale_update_cell=update_cell)
+        net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
+                                                           scale_update_cell=update_cell)
     else:
-        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
+        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

-    model = Model(netwithgrads)
-    model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
+    model = Model(net_with_grads)
+    model.train(new_repeat_count, ds, callbacks=callback,
+                dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)

 if __name__ == '__main__':
     numpy.random.seed(0)
     run_pretrain()
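The training invocation is the other half of the fix: with sinking enabled, the first argument to Model.train() counts sink iterations of sink_size steps, not passes over the dataset. Recombining the new lines above into one place (the names come from run_pretrain(); this is a call pattern, not a standalone script):

    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
    model = Model(net_with_grads)
    model.train(new_repeat_count, ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
    # Each reported "epoch" now executes exactly data_sink_steps steps, which is
    # why TimeMonitor takes data_sink_steps instead of ds.get_dataset_size().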
@@ -23,11 +23,9 @@ from mindspore import log as logger
 from .config import bert_net_cfg


-def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
-                        data_sink_steps=1, data_dir=None, schema_dir=None):
+def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
     """create train dataset"""
     # apply repeat operations
-    repeat_count = epoch_size
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:
@@ -40,11 +38,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                             num_shards=device_num, shard_id=rank, shard_equal_rows=True)
-    ori_dataset_size = ds.get_dataset_size()
-    print('origin dataset size: ', ori_dataset_size)
-    new_size = ori_dataset_size
-    if enable_data_sink == "true":
-        new_size = data_sink_steps * bert_net_cfg.batch_size
-    ds.set_dataset_size(new_size)
-    new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -55,8 +48,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
     logger.info("data size: {}".format(ds.get_dataset_size()))
-    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
-    return ds, new_repeat_count
+    logger.info("repeat count: {}".format(ds.get_repeat_count()))
+    return ds


 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
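On the dataset side, create_bert_dataset() now builds one plain pass over the data: no repeat(), no set_dataset_size(), no derived repeat count. A condensed sketch of the resulting pipeline (imports and the BERT pretraining column names follow the scripts of this era; batch_size stands in for bert_net_cfg.batch_size):

    import os
    import mindspore.common.dtype as mstype
    import mindspore.dataset as de
    import mindspore.dataset.transforms.c_transforms as C

    def create_bert_dataset_sketch(device_num=1, rank=0, do_shuffle="true",
                                   data_dir=None, schema_dir=None, batch_size=32):
        # Gather the TFRecord shards (the real script filters the listing first).
        data_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
        ds = de.TFRecordDataset(data_files, schema_dir,
                                columns_list=["input_ids", "input_mask", "segment_ids",
                                              "next_sentence_labels", "masked_lm_positions",
                                              "masked_lm_ids", "masked_lm_weights"],
                                shuffle=(do_shuffle == "true"),
                                num_shards=device_num, shard_id=rank, shard_equal_rows=True)
        # One pass == one epoch; the caller now decides how many repeats to train.
        type_cast_op = C.TypeCast(mstype.int32)
        ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
        ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
        ds = ds.batch(batch_size, drop_remainder=True)
        return ds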
@@ -18,6 +18,7 @@ Functional Cells used in Bert finetune and evaluation.
 """

+import os
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor