add st for bert_thor

This commit is contained in:
wangmin 2020-09-15 18:52:27 +08:00
parent 5ffec7efaf
commit d96ecd2740
1 changed file with 210 additions and 0 deletions


@@ -0,0 +1,210 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test bert thor performance with 8p on mlperf dataset"""
import os
import time
from multiprocessing import Process, Queue
import pytest
import numpy as np
import mindspore.dataset as dataset
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context
from mindspore import log as logger
from mindspore.train.callback import Callback
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from model_zoo.official.nlp.bert_thor.src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepCell
from model_zoo.official.nlp.bert_thor.src.bert_net_config import bert_net_cfg
from model_zoo.official.nlp.bert_thor.src.config import cfg
from model_zoo.official.nlp.bert_thor.src.lr_generator import get_bert_lr, get_bert_damping
from model_zoo.official.nlp.bert_thor.src.model_thor import Model
from model_zoo.official.nlp.bert_thor.src.thor_for_bert_arg import THOR
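# Test configuration: the HCCL rank table and MLPerf en-wiki dataset paths below are specific
# to the CI environment; data_sink_steps and train_steps bound how many sink cycles are run.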
MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_table_8p.json"
DATASET_PATH = "/home/workspace/mindspore_dataset/bert/thor/en-wiki-512_test_first1wan"
load_checkpoint_path = ""
data_sink_steps = 100
train_steps = 200
batch_size = 12
np.random.seed(1)
dataset.config.set_seed(1)
os.environ['GLOG_v'] = str(2)
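# Callback that times each sink cycle (reported as an "epoch" in data-sink mode) and derives
# the average per-step cost in milliseconds.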
class TimeMonitor(Callback):
"""Time Monitor."""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
self.epoch_mseconds_list = []
self.per_step_mseconds_list = []
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_mseconds = (time.time() - self.epoch_time) * 1000
self.epoch_mseconds_list.append(epoch_mseconds)
per_step_mseconds = epoch_mseconds / self.data_size
self.per_step_mseconds_list.append(per_step_mseconds)
print("epoch: {}, per_step_mseconds are {}".format(cb_params.cur_epoch_num, str(per_step_mseconds)), flush=True)
class LossCallback(Callback):
def __init__(self):
super(LossCallback, self).__init__()
self.loss_list = []
def epoch_end(self, run_context):
cb_params = run_context.original_args()
self.loss_list.append(cb_params.net_outputs.asnumpy())
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)), flush=True)
def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
"""create train dataset"""
# collect the tfrecord data files under data_dir
files = os.listdir(data_dir)
data_files = []
for file_name in files:
if "tfrecord" in file_name:
data_files.append(os.path.join(data_dir, file_name))
data_files = sorted(data_files)
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
num_shards=device_num, shard_id=rank, shard_equal_rows=True)
ori_dataset_size = ds.get_dataset_size()
print('original dataset size: ', ori_dataset_size)
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op, input_columns="masked_lm_ids")
ds = ds.map(operations=type_cast_op, input_columns="masked_lm_positions")
ds = ds.map(operations=type_cast_op, input_columns="next_sentence_labels")
ds = ds.map(operations=type_cast_op, input_columns="segment_ids")
ds = ds.map(operations=type_cast_op, input_columns="input_mask")
ds = ds.map(operations=type_cast_op, input_columns="input_ids")
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeat count: {}".format(ds.get_repeat_count()))
return ds
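# Per-device worker: binds to one Ascend device, initializes HCCL data parallelism, builds the
# BERT network with the second-order THOR optimizer, trains in data-sink mode, and returns the
# per-device loss and per-step cost through the multiprocessing queue.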
def train_process_bert_thor(q, device_id, epoch_size, device_num):
for i in range(device_num):
os.system("rm -rf " + str(i))
os.system("mkdir " + str(device_id))
os.chdir(str(device_id))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(max_call_depth=3000)
os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
os.environ['RANK_ID'] = str(device_id)
os.environ['RANK_SIZE'] = str(device_num)
D.init()
rank = device_id % device_num
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
bert_net_cfg.num_hidden_layers = 2
ds = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle="false", data_dir=DATASET_PATH, schema_dir=None)
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
new_repeat_count = epoch_size * ds.get_dataset_size() // data_sink_steps
new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)
lr = get_bert_lr()
damping = get_bert_damping()
optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
bert_net_cfg.batch_size, damping)
time_monitor_callback = TimeMonitor(data_sink_steps)
loss_callback = LossCallback()
callback = [time_monitor_callback, loss_callback]
if load_checkpoint_path:
param_dict = load_checkpoint(load_checkpoint_path)
load_param_into_net(net_with_loss, param_dict)
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
model = Model(net_with_grads, frequency=cfg.Thor.frequency)
model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=True, sink_size=data_sink_steps)
loss_list = loss_callback.loss_list
per_step_mseconds = time_monitor_callback.per_step_mseconds_list
q.put({'loss': loss_list, 'cost': per_step_mseconds})
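# Spawn one training process per device, sum the reported losses and per-step costs across
# devices, and assert that the 8-device means stay within the cost and loss thresholds below.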
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_bert_thor_mlperf_8p():
"""test bert thor mlperf 8p"""
q = Queue()
device_num = 8
epoch_size = 2
process = []
for i in range(device_num):
device_id = i
process.append(Process(target=train_process_bert_thor, args=(q, device_id, epoch_size, device_num)))
for i in range(device_num):
process[i].start()
print("Waiting for all subprocesses done...")
for i in range(device_num):
process[i].join()
sum_loss_list = []
sum_cost_list = []
for _ in range(train_steps // data_sink_steps):
sum_loss_list.append(0.0)
sum_cost_list.append(0.0)
for _ in range(device_num):
output = q.get()
loss_list = output['loss']
cost_list = output['cost']
sum_loss_list = np.sum([loss_list, sum_loss_list], axis=0)
sum_cost_list = np.sum([cost_list, sum_cost_list], axis=0)
for j in range(train_steps // data_sink_steps):
print("epoch: ", j, "sum_loss: ", sum_loss_list[j], "sum_cost: ", sum_cost_list[j])
mean_loss = sum_loss_list[-1] / device_num
mean_cost = sum_cost_list[-1] / device_num
for i in range(device_num):
os.system("rm -rf " + str(i))
print("End training...")
assert mean_cost < 51
assert mean_loss < 8.5
if __name__ == '__main__':
test_bert_thor_mlperf_8p()