forked from mindspore-Ecosystem/mindspore

commit 1be7ad52bb
parent 5958c4abc6

    fix bert scripts
@@ -19,6 +19,7 @@ python run_pretrain.py
 import os
 import argparse
+import numpy
 import mindspore.communication.management as D
 from mindspore import context
 from mindspore.train.model import Model
@@ -142,4 +143,5 @@ def run_pretrain():
     model = Model(netwithgrads)
     model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
 if __name__ == '__main__':
+    numpy.random.seed(0)
     run_pretrain()
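The new "import numpy" pairs with the "numpy.random.seed(0)" call added before run_pretrain(): pinning the global NumPy RNG makes any NumPy-backed initialization or shuffling reproducible across runs. A minimal sketch of the effect, plain NumPy only:

    import numpy

    numpy.random.seed(0)             # pin the global NumPy RNG state
    first = numpy.random.rand(3)

    numpy.random.seed(0)             # re-seeding replays the same stream
    second = numpy.random.rand(3)

    assert (first == second).all()   # identical draws on every run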
@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                          shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                          shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
     new_size = ori_dataset_size
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
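With data sinking enabled, the script overrides the dataset size in samples rather than taking the file's natural size, and the added print records the original size before that override. A sketch of the arithmetic with hypothetical numbers standing in for the script's values:

    # hypothetical values, for illustration only
    data_sink_steps = 100            # batches handed to the device per sink
    batch_size = 32                  # bert_net_cfg.batch_size in the script

    new_size = data_sink_steps * batch_size   # samples consumed per sink
    assert new_size == 3200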
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds, new_repeat_count
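Repeating the dataset max(new_repeat_count, repeat_count) times guards against the computed new_repeat_count coming out smaller than repeat_count (presumably the caller's requested repeat count), which would have ended training early. A plain-Python sketch with hypothetical counts:

    # hypothetical counts, for illustration only
    new_repeat_count = 3   # derived from sink size and epoch settings
    repeat_count = 5       # repeats requested by the caller

    # before the fix the dataset repeated only new_repeat_count times;
    # after it, it repeats whichever count is larger
    assert max(new_repeat_count, repeat_count) == 5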
@@ -32,7 +32,6 @@ from .bert_model import BertModel
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
 
-_nn_clip_by_norm = nn.ClipByNorm()
 clip_grad = C.MultitypeFuncGraph("clip_grad")
 
 
@@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
         new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                    F.cast(F.tuple_to_array((clip_value,)), dt))
     else:
-        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
     return new_grad
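Constructing nn.ClipByNorm() at the call site, instead of referencing the removed module-level _nn_clip_by_norm cell, keeps the function registered on the clip_grad MultitypeFuncGraph self-contained, presumably to avoid trouble with the shared cell when the function is compiled into the training graph. Functionally, ClipByNorm rescales a tensor so its L2 norm does not exceed the bound; a NumPy sketch of that behavior (an illustration, not the MindSpore implementation):

    import numpy as np

    def clip_by_norm(grad, clip_value):
        # rescale grad so that its L2 norm is at most clip_value
        norm = np.linalg.norm(grad)
        if norm > clip_value:
            grad = grad * (clip_value / norm)
        return grad

    g = np.array([3.0, 4.0])            # L2 norm is 5.0
    clipped = clip_by_norm(g, 1.0)
    assert np.isclose(np.linalg.norm(clipped), 1.0)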
@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
@@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
     bert_net_cfg = BertConfig(
         batch_size=16,
         seq_length=512,
-        vocab_size=30528,
+        vocab_size=30522,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
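The vocab_size corrections align the configs with the tokenizer vocabularies these models actually ship: the Chinese BERT/NEZHA vocab.txt has 21128 entries and the original English uncased BERT vocab has 30522, replacing the incorrect 21136 and 30528. A quick sanity check one could run against a vocab file (path and usage are hypothetical):

    def vocab_file_entries(path):
        # each line of a BERT vocab.txt is one wordpiece token
        with open(path, encoding="utf-8") as f:
            return sum(1 for _ in f)

    # expect 21128 for the Chinese vocab, 30522 for English uncased, e.g.:
    # assert vocab_file_entries("vocab.txt") == bert_net_cfg.vocab_size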
The same two dataset hunks are applied a second time, to another copy of the dataset script:

@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                          shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                          shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
     new_size = ori_dataset_size
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds, new_repeat_count