fix bert scripts

chenhaozhe 2020-06-08 19:23:36 +08:00
parent 5958c4abc6
commit 1be7ad52bb
5 changed files with 10 additions and 7 deletions

View File

@@ -19,6 +19,7 @@ python run_pretrain.py
import os
import argparse
+import numpy
import mindspore.communication.management as D
from mindspore import context
from mindspore.train.model import Model
@@ -142,4 +143,5 @@ def run_pretrain():
    model = Model(netwithgrads)
    model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
if __name__ == '__main__':
+    numpy.random.seed(0)
    run_pretrain()
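A minimal standalone sketch of what the added numpy.random.seed(0) buys: seeding NumPy's global RNG before run_pretrain() starts makes any NumPy-based shuffling or initialization repeat identically across runs (illustrative snippet, not part of the diff):

import numpy
numpy.random.seed(0)
first = numpy.random.permutation(10)
numpy.random.seed(0)
second = numpy.random.permutation(10)
assert (first == second).all()   # same seed, same ordering on every run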

View File

@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                            shard_equal_rows=True)
    ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
    new_size = ori_dataset_size
    if enable_data_sink == "true":
        new_size = data_sink_steps * bert_net_cfg.batch_size
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
    logger.info("data size: {}".format(ds.get_dataset_size()))
    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
    return ds, new_repeat_count
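The repeat change guards against new_repeat_count (derived from the data-sink step count) falling below the requested repeat_count, so the pipeline never yields fewer passes than asked for. A rough standalone sketch of the size arithmetic with made-up numbers (not taken from this repo):

# hypothetical numbers: 1000 batches per epoch, caller asked for 4 repeats,
# but the sink-based calculation only came out to 3
batches_per_epoch = 1000
repeat_count = 4
new_repeat_count = 3
effective_repeats = max(new_repeat_count, repeat_count)
print(batches_per_epoch * effective_repeats)   # 4000 batches flow through ds.repeat(...)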

View File

@@ -32,7 +32,6 @@ from .bert_model import BertModel
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
-_nn_clip_by_norm = nn.ClipByNorm()
clip_grad = C.MultitypeFuncGraph("clip_grad")
@@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
-        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
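For reference, nn.ClipByNorm is a Cell whose call takes the input tensor and a clip_norm tensor and rescales the input so its L2 norm does not exceed clip_norm; building the cell inline, as the new line does, removes the dependence on the deleted module-level instance. A minimal sketch of the operator on its own (illustrative values, not part of the diff):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

clip = nn.ClipByNorm()
x = Tensor(np.array([[3.0, 4.0]], dtype=np.float32))    # L2 norm 5.0
clip_norm = Tensor(np.array([1.0], dtype=np.float32))
print(clip(x, clip_norm))                               # rescaled so the norm is at most 1.0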

View File

@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
    bert_net_cfg = BertConfig(
        batch_size=32,
        seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
@@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
    bert_net_cfg = BertConfig(
        batch_size=32,
        seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
@@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
    bert_net_cfg = BertConfig(
        batch_size=16,
        seq_length=512,
-        vocab_size=30528,
+        vocab_size=30522,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
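The corrected values appear to line up with the published BERT vocabularies (21128 entries in the Chinese vocab.txt, presumably used by the base/nezha configs, and 30522 in the English uncased vocab.txt for the large config). A quick way to check against whichever vocab file is actually paired with the config (the path here is an assumption):

# count tokens in the vocab file the config must match (path is hypothetical)
with open("vocab.txt", encoding="utf-8") as f:
    print(sum(1 for _ in f))   # should equal the vocab_size set above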

View File

@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                            shard_equal_rows=True)
    ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
    new_size = ori_dataset_size
    if enable_data_sink == "true":
        new_size = data_sink_steps * bert_net_cfg.batch_size
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
    logger.info("data size: {}".format(ds.get_dataset_size()))
    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
    return ds, new_repeat_count