forked from mindspore-Ecosystem/mindspore
fix bert scripts
This commit is contained in:
parent
5958c4abc6
commit
1be7ad52bb
|
@ -19,6 +19,7 @@ python run_pretrain.py
|
|||
|
||||
import os
|
||||
import argparse
|
||||
import numpy
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
|
@ -142,4 +143,5 @@ def run_pretrain():
|
|||
model = Model(netwithgrads)
|
||||
model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
|
||||
if __name__ == '__main__':
|
||||
numpy.random.seed(0)
|
||||
run_pretrain()
|
||||
|
|
|
@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|||
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
|
||||
shard_equal_rows=True)
|
||||
ori_dataset_size = ds.get_dataset_size()
|
||||
print('origin dataset size: ', ori_dataset_size)
|
||||
new_size = ori_dataset_size
|
||||
if enable_data_sink == "true":
|
||||
new_size = data_sink_steps * bert_net_cfg.batch_size
|
||||
|
@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(new_repeat_count)
|
||||
ds = ds.repeat(max(new_repeat_count, repeat_count))
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
return ds, new_repeat_count
|
||||
|
|
|
@ -32,7 +32,6 @@ from .bert_model import BertModel
|
|||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
_nn_clip_by_norm = nn.ClipByNorm()
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
|
@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
|
|||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
|
|||
bert_net_cfg = BertConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=21136,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
|
@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
|
|||
bert_net_cfg = BertConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=21136,
|
||||
vocab_size=21128,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
|
@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
|
|||
bert_net_cfg = BertConfig(
|
||||
batch_size=16,
|
||||
seq_length=512,
|
||||
vocab_size=30528,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
|
|
|
@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|||
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
|
||||
shard_equal_rows=True)
|
||||
ori_dataset_size = ds.get_dataset_size()
|
||||
print('origin dataset size: ', ori_dataset_size)
|
||||
new_size = ori_dataset_size
|
||||
if enable_data_sink == "true":
|
||||
new_size = data_sink_steps * bert_net_cfg.batch_size
|
||||
|
@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(new_repeat_count)
|
||||
ds = ds.repeat(max(new_repeat_count, repeat_count))
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
return ds, new_repeat_count
|
||||
|
|
Loading…
Reference in New Issue