forked from mindspore-Ecosystem/mindspore

modify pre-training script

parent 29dc8048f2
commit c8221cce4d
@@ -1,57 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
network config setting, will be used in train.py
"""

from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig

bert_train_cfg = edict({
    'epoch_size': 10,
    'num_warmup_steps': 0,
    'start_learning_rate': 1e-4,
    'end_learning_rate': 0.0,
    'decay_steps': 1000,
    'power': 10.0,
    'save_checkpoint_steps': 2000,
    'keep_checkpoint_max': 10,
    'checkpoint_prefix': "checkpoint_bert",
    # please add your own dataset path
    'DATA_DIR': "/your/path/examples.tfrecord",
    # please add your own dataset schema path
    'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
bert_net_cfg = BertConfig(
    batch_size=16,
    seq_length=128,
    vocab_size=21136,
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=True,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
)
@@ -1,96 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is a Chinese pretrained
language model based on BERT, developed by Huawei.
1. Prepare data
Following the data preparation as in BERT, run the command below to generate the dataset for training:
    python ./create_pretraining_data.py \
        --input_file=./sample_text.txt \
        --output_file=./examples.tfrecord \
        --vocab_file=./your/path/vocab.txt \
        --do_lower_case=True \
        --max_seq_length=128 \
        --max_predictions_per_seq=20 \
        --masked_lm_prob=0.15 \
        --random_seed=12345 \
        --dupe_factor=5
2. Pretrain
First, prepare the distributed training environment, then adjust configurations in config.py, and finally run train.py.
"""

import os
import numpy as np
from config import bert_train_cfg, bert_net_cfg
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import context
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb
_current_dir = os.path.dirname(os.path.realpath(__file__))

def create_train_dataset(batch_size):
    """create train dataset"""
    # apply repeat operations
    repeat_count = bert_train_cfg.epoch_size
    ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
                            columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                          "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
    ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
    ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)
    return ds

def weight_variable(shape):
    """weight variable"""
    np.random.seed(1)
    ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
    return Tensor(ones)

def train_bert():
    """train bert"""
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target="Ascend")
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    ds = create_train_dataset(bert_net_cfg.batch_size)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
                     start_learning_rate=bert_train_cfg.start_learning_rate,
                     end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
                     warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
    model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)

if __name__ == '__main__':
    train_bert()
@@ -4,19 +4,25 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](

## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the zhwiki dataset from <https://dumps.wikimedia.org/zhwiki> for pre-training. Extract and clean the text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path (see the example command after this list).
- Download the CLUE dataset from <https://www.cluebenchmarks.com> for fine-tuning and evaluation.
> Notes:
If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.
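
Converting the cleaned text to TFRecord follows the standard BERT recipe. A representative conversion command (the paths and argument values below are placeholders to adapt to your own corpus) is:

``` bash
python ./create_pretraining_data.py \
    --input_file=./sample_text.txt \
    --output_file=./examples.tfrecord \
    --vocab_file=./your/path/vocab.txt \
    --do_lower_case=True \
    --max_seq_length=128 \
    --max_predictions_per_seq=20 \
    --masked_lm_prob=0.15 \
    --random_seed=12345 \
    --dupe_factor=5
```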

## Running the Example
### Pre-Training
- Set options in `config.py`. Make sure that 'DATA_DIR' (path to the dataset) and 'SCHEMA_DIR' (path to the JSON schema file) are set to your own paths. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about the dataset and the JSON schema file.
- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about the dataset and the JSON schema file.

- Run `run_pretrain.py` for pre-training of the BERT-base and BERT-NEZHA models.
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of the BERT-base and BERT-NEZHA models.

``` bash
python run_pretrain.py --backend=ms
sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH
```
- Run `run_distribute_pretrain.sh` for distributed pre-training of the BERT-base and BERT-NEZHA models.

``` bash
sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH
```

### Fine-Tuning

@@ -40,30 +46,42 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
## Usage
### Pre-Training
```
usage: run_pretrain.py [--backend BACKEND]
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
                       [--enable_task_sink ENABLE_TASK_SINK] [--enable_loop_sink ENABLE_LOOP_SINK]
                       [--enable_mem_reuse ENABLE_MEM_REUSE] [--enable_save_ckpt ENABLE_SAVE_CKPT]
                       [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
                       [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
                       [--save_checkpoint_steps N] [--save_checkpoint_num N]
                       [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]

optional parameters:
    --backend BACKEND          MindSpore backend: ms
options:
    --distribute               pre-training on several devices: "true" (training with more than 1 device) | "false", default is "false"
    --epoch_size               epoch size: N, default is 1
    --device_num               number of used devices: N, default is 1
    --device_id                device id: N, default is 0
    --enable_task_sink         enable task sink: "true" | "false", default is "true"
    --enable_loop_sink         enable loop sink: "true" | "false", default is "true"
    --enable_mem_reuse         enable memory reuse: "true" | "false", default is "true"
    --enable_save_ckpt         enable saving checkpoints: "true" | "false", default is "true"
    --enable_lossscale         enable lossscale: "true" | "false", default is "true"
    --do_shuffle               enable shuffle: "true" | "false", default is "true"
    --enable_data_sink         enable data sink: "true" | "false", default is "true"
    --data_sink_steps          set data sink steps: N, default is 1
    --checkpoint_path          path to load a checkpoint file from: PATH, default is ""
    --save_checkpoint_steps    steps between saving checkpoint files: N, default is 1000
    --save_checkpoint_num      number of checkpoint files to keep: N, default is 1
    --data_dir                 path to the dataset directory: PATH, default is ""
    --schema_dir               path to the schema.json file: PATH, default is ""
```
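
For example, a minimal single-device run (the paths below are placeholders; all flags are those listed above) could look like:

``` bash
python run_pretrain.py --distribute="false" --epoch_size=40 --device_id=0 \
    --enable_data_sink="true" --data_sink_steps=1 \
    --data_dir=/your/path/zh-wiki/ --schema_dir=/your/path/datasetSchema.json
```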

## Options and Parameters
It contains the parameters of the BERT model and options for training, which are set in the files `config.py`, `finetune_config.py` and `evaluation_config.py` respectively.
### Options:
```
Pre-Training:
    bert_network                    version of BERT model: base | large, default is base
    epoch_size                      repeat count of training: N, default is 40
    dataset_sink_mode               use dataset sink mode or not: True | False, default is True
    do_shuffle                      shuffle the dataset or not: True | False, default is True
    do_train_with_lossscale         use lossscale or not: True | False, default is True
    loss_scale_value                initial value of loss scale: N, default is 2^32
    scale_factor                    factor used to update loss scale: N, default is 2
    scale_window                    steps between updates of loss scale: N, default is 1000
    save_checkpoint_steps           steps between saving checkpoints: N, default is 2000
    keep_checkpoint_max             maximum number of checkpoints to keep: N, default is 1
    init_ckpt                       checkpoint file to load: PATH, default is ""
    data_dir                        dataset file to load: PATH, default is "/your/path/cn-wiki-128"
    schema_dir                      dataset schema file to load: PATH, default is "your/path/datasetSchema.json"
    optimizer                       optimizer used in the network: AdamWeightDecayDynamicLR | Lamb | Momentum, default is "Lamb"

Fine-Tuning:
@@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, run_pretrain.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
cfg = edict({
    'bert_network': 'base',
    'loss_scale_value': 2**32,
    'scale_factor': 2,
    'scale_window': 1000,
    'optimizer': 'Lamb',
    'AdamWeightDecayDynamicLR': edict({
        'learning_rate': 3e-5,
        'end_learning_rate': 0.0,
        'power': 5.0,
        'weight_decay': 1e-5,
        'eps': 1e-6,
    }),
    'Lamb': edict({
        'start_learning_rate': 3e-5,
        'end_learning_rate': 0.0,
        'power': 10.0,
        'warmup_steps': 10000,
        'weight_decay': 0.01,
        'eps': 1e-6,
        'decay_filter': lambda x: False,
    }),
    'Momentum': edict({
        'learning_rate': 2e-5,
        'momentum': 0.9,
    }),
})
if cfg.bert_network == 'base':
    bert_net_cfg = BertConfig(
        batch_size=16,
        seq_length=128,
        vocab_size=21136,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=False,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16,
    )
else:
    bert_net_cfg = BertConfig(
        batch_size=16,
        seq_length=128,
        vocab_size=21136,
        hidden_size=1024,
        num_hidden_layers=12,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=True,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16,
    )
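
The `loss_scale_value`, `scale_factor` and `scale_window` fields above feed MindSpore's `DynamicLossScaleUpdateCell` in `run_pretrain.py`. As a rough, hypothetical sketch of the standard dynamic loss-scaling rule these three values control (not MindSpore's actual implementation):

```python
# Hypothetical illustration of dynamic loss scaling, assuming the usual rule:
# shrink the scale when gradients overflow, grow it after `scale_window`
# consecutive clean steps.
class DynamicLossScaleSketch:
    def __init__(self, loss_scale_value=2**32, scale_factor=2, scale_window=1000):
        self.scale = float(loss_scale_value)
        self.factor = scale_factor
        self.window = scale_window
        self.good_steps = 0

    def update(self, overflow):
        if overflow:
            # back off quickly so the next step stays in range
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps % self.window == 0:
                # grow cautiously after a long run of clean steps
                self.scale *= self.factor
        return self.scale
```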
@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in run_pretrain.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import log as logger
from config import bert_net_cfg


def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
                        data_sink_steps=1, data_dir=None, schema_dir=None):
    """create train dataset"""
    # apply repeat operations
    repeat_count = epoch_size
    files = os.listdir(data_dir)
    data_files = []
    for file_name in files:
        data_files.append(data_dir + file_name)
    ds = de.TFRecordDataset(data_files, schema_dir,
                            columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                          "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                            shard_equal_rows=True)
    ori_dataset_size = ds.get_dataset_size()
    new_size = ori_dataset_size
    if enable_data_sink == "true":
        # In data sink mode one logical "epoch" processes data_sink_steps batches, so the
        # dataset size is shrunk accordingly and the repeat count is scaled up below to
        # keep the total number of trained samples unchanged.
        new_size = data_sink_steps * bert_net_cfg.batch_size
    ds.set_dataset_size(new_size)
    repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
    ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
    ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)
    logger.info("data size: {}".format(ds.get_dataset_size()))
    logger.info("repeat count: {}".format(ds.get_repeat_count()))
    return ds
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH"
echo "for example: sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json /path/mindspore"
echo "It is better to use absolute path."
echo "=============================================================================================================="

EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
MINDSPORE_PATH=$6

export PYTHONPATH=$MINDSPORE_PATH/build/package:$PYTHONPATH
export MINDSPORE_HCCL_CONFIG_PATH=$5
export RANK_SIZE=$1

for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID=$i
    # bind each rank to its own block of 12 CPU cores: [i*12, i*12+11]
    start=`expr $i \* 12`
    end=`expr $start \+ 11`
    cmdopt=$start"-"$end

    rm -rf LOG$i
    mkdir ./LOG$i
    cp *.py ./LOG$i
    cd ./LOG$i || exit
    export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    taskset -c $cmdopt python ../run_pretrain.py \
    --distribute="true" \
    --epoch_size=$EPOCH_SIZE \
    --device_id=$DEVICE_ID \
    --device_num=$RANK_SIZE \
    --enable_task_sink="true" \
    --enable_loop_sink="true" \
    --enable_mem_reuse="true" \
    --enable_save_ckpt="true" \
    --enable_lossscale="true" \
    --do_shuffle="true" \
    --enable_data_sink="true" \
    --data_sink_steps=1 \
    --checkpoint_path="" \
    --save_checkpoint_steps=1000 \
    --save_checkpoint_num=1 \
    --data_dir=$DATA_DIR \
    --schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
    cd ../
done
@@ -0,0 +1,144 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################pre_train bert example on zh-wiki########################
python run_pretrain.py
"""

import os
import argparse
import mindspore.communication.management as D
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
from dataset import create_bert_dataset
from config import cfg, bert_net_cfg
_current_dir = os.path.dirname(os.path.realpath(__file__))

class LossCallBack(Callback):
    """
    Monitor the loss during training.
    If the loss is NAN or INF, terminate training.
    Note:
        If per_print_times is 0, the loss is not printed.
    Args:
        per_print_times (int): Print the loss every per_print_times steps. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        with open("./loss.log", "a+") as f:
            f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                                 str(cb_params.net_outputs)))
            f.write('\n')

def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_task_sink", type=str, default="true", help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                                                                                "default is 1000.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_context(enable_hccl=True)
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        context.set_context(enable_hccl=False)
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink,
                             args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                         start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps, decay_filter=cfg.Lamb.decay_filter)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
                                             decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                                             learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=cfg.AdamWeightDecayDynamicLR.power,
                                             weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=cfg.AdamWeightDecayDynamicLR.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
                         format(cfg.optimizer))
    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.checkpoint_path:
        # optionally resume from an existing checkpoint
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                         scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))

if __name__ == '__main__':
    run_pretrain()
@@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH"
echo "for example: sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json /path/mindspore"
echo "=============================================================================================================="

DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
MINDSPORE_PATH=$5
export PYTHONPATH=$MINDSPORE_PATH/build/package:$PYTHONPATH

python run_pretrain.py \
    --distribute="false" \
    --epoch_size=$EPOCH_SIZE \
    --device_id=$DEVICE_ID \
    --enable_task_sink="true" \
    --enable_loop_sink="true" \
    --enable_mem_reuse="true" \
    --enable_save_ckpt="true" \
    --enable_lossscale="true" \
    --do_shuffle="true" \
    --enable_data_sink="true" \
    --data_sink_steps=1 \
    --checkpoint_path="" \
    --save_checkpoint_steps=1000 \
    --save_checkpoint_num=1 \
    --data_dir=$DATA_DIR \
    --schema_dir=$SCHEMA_DIR > log.txt 2>&1 &