!6 add bert example

Merge pull request !6 from yoonlee666/master
This commit is contained in:
mindspore-ci-bot 2020-03-30 19:52:51 +08:00 committed by Gitee
commit 1a3c06a948
3 changed files with 82 additions and 96 deletions

View File

@ -1,55 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_cfg = edict({
'epoch_size': 10,
'num_warmup_steps': 0,
'start_learning_rate': 1e-4,
'end_learning_rate': 1,
'decay_steps': 1000,
'power': 10.0,
'save_checkpoint_steps': 2000,
'keep_checkpoint_max': 10,
'checkpoint_prefix': "checkpoint_bert",
'DATA_DIR' = "/your/path/examples.tfrecord"
'SCHEMA_DIR' = "/your/path/datasetSchema.json"
'bert_config': BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
})

View File

@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_train_cfg = edict({
'epoch_size': 10,
'num_warmup_steps': 0,
'start_learning_rate': 1e-4,
'end_learning_rate': 0.0,
'decay_steps': 1000,
'power': 10.0,
'save_checkpoint_steps': 2000,
'keep_checkpoint_max': 10,
'checkpoint_prefix': "checkpoint_bert",
# please add your own dataset path
'DATA_DIR': "/your/path/examples.tfrecord",
# please add your own dataset schema path
'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)

View File

@ -14,7 +14,8 @@
# ============================================================================ # ============================================================================
""" """
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei. NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language
model currently based on BERT developed by Huawei.
1. Prepare data 1. Prepare data
Following the data preparation as in BERT, run command as below to get dataset for training: Following the data preparation as in BERT, run command as below to get dataset for training:
python ./create_pretraining_data.py \ python ./create_pretraining_data.py \
@ -28,35 +29,29 @@ Following the data preparation as in BERT, run command as below to get dataset f
--random_seed=12345 \ --random_seed=12345 \
--dupe_factor=5 --dupe_factor=5
2. Pretrain 2. Pretrain
First, prepare the distributed training environment, then adjust configurations in config.py, finally run main.py. First, prepare the distributed training environment, then adjust configurations in config.py, finally run train.py.
""" """
import os import os
import pytest
import numpy as np import numpy as np
from numpy import allclose from config import bert_train_cfg, bert_net_cfg
from config import bert_cfg as cfg
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de import mindspore.dataset.engine.datasets as de
import mindspore._c_dataengine as deMap import mindspore._c_dataengine as deMap
from mindspore import context from mindspore import context
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import Callback from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb from mindspore.nn.optim import Lamb
from mindspore import log as logger
_current_dir = os.path.dirname(os.path.realpath(__file__)) _current_dir = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = [cfg.DATA_DIR]
SCHEMA_DIR = cfg.SCHEMA_DIR
def me_de_train_dataset(batch_size): def create_train_dataset(batch_size):
"""test me de train dataset""" """create train dataset"""
# apply repeat operations # apply repeat operations
repeat_count = cfg.epoch_size repeat_count = bert_train_cfg.epoch_size
ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
"next_sentence_labels", "masked_lm_positions", columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_ids", "masked_lm_weights"]) "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
type_cast_op = deMap.TypeCastOp("int32") type_cast_op = deMap.TypeCastOp("int32")
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@ -69,43 +64,32 @@ def me_de_train_dataset(batch_size):
ds = ds.repeat(repeat_count) ds = ds.repeat(repeat_count)
return ds return ds
def weight_variable(shape): def weight_variable(shape):
"""weight variable""" """weight variable"""
np.random.seed(1) np.random.seed(1)
ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32) ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
return Tensor(ones) return Tensor(ones)
def train_bert():
class ModelCallback(Callback): """train bert"""
def __init__(self):
super(ModelCallback, self).__init__()
self.loss_list = []
def step_end(self, run_context):
cb_params = run_context.original_args()
self.loss_list.append(cb_params.net_outputs.asnumpy()[0])
logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
def test_bert_tdt():
"""test bert tdt"""
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend") context.set_context(device_target="Ascend")
context.set_context(enable_task_sink=True) context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)
parallel_callback = ModelCallback() ds = create_train_dataset(bert_net_cfg.batch_size)
ds = me_de_train_dataset(cfg.bert_config.batch_size) netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
config = cfg.bert_config optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
netwithloss = BertNetworkWithLoss(config, True) start_learning_rate=bert_train_cfg.start_learning_rate,
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=cfg.decay_steps, start_learning_rate=cfg.start_learning_rate, end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
end_learning_rate=cfg.end_learning_rate, power=cfg.power, warmup_steps=cfg.num_warmup_steps, decay_filter=lambda x: False) warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer) netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
netwithgrads.set_train(True) netwithgrads.set_train(True)
model = Model(netwithgrads) model = Model(netwithgrads)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix, config=config_ck) keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
model.train(ds.get_repeat_count(), ds, callbacks=[parallel_callback, ckpoint_cb], dataset_sink_mode=False) ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)
if __name__ == '__main__': if __name__ == '__main__':
test_bert_tdt() train_bert()