retrieve modelzoo network gnmt_v2 and ncf testcase

This commit is contained in:
anzhengqi 2021-06-04 15:50:16 +08:00
parent 475386e338
commit 6108bb444b
5 changed files with 0 additions and 298 deletions

View File

@ -18,8 +18,3 @@ packaging >= 20.0
pycocotools >= 2.0.2 # for st test
tables >= 3.6.1 # for st test
psutil >= 5.7.0
subword-nmt>=0.3.7 # for st test
sacrebleu>=1.4.14 # for st test
sacremoses>=0.0.35 # for st test
absl-py>=0.10.0 # for st test
six>=1.15.0 # for st test

View File

@ -1,103 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train and eval api."""
import os
import argparse
import pickle
import datetime
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from config import GNMTConfig
from train import train_parallel
from src.gnmt_model import infer
from src.gnmt_model.bleu_calculate import bleu_calculate
from src.dataset.tokenizer import Tokenizer
# Command-line interface for the combined train + eval system test.
parser = argparse.ArgumentParser(description='GNMT train and eval.')

# Mandatory string options, registered in the order they show up in --help:
# the first two drive the training phase, the rest drive evaluation.
for _flag, _help in (
        ("--config_train", "model config json file path."),
        ("--pre_train_dataset", "pre-train dataset address."),
        ("--config_test", "model config json file path."),
        ("--test_dataset", "test dataset address."),
        ("--existed_ckpt", "existed checkpoint address."),
        ("--vocab", "Vocabulary to use."),
        ("--bpe_codes", "bpe codes to use."),
):
    parser.add_argument(_flag, type=str, required=True, help=_help)
# Do not leave the loop variables behind as module attributes.
del _flag, _help

# Reference file that the BLEU computation compares against.
parser.add_argument(
    "--test_tgt",
    type=str,
    required=True,
    default=None,
    help="data file of the test target",
)
# Only optional switch: where the pickled inference result is written.
parser.add_argument(
    "--output",
    type=str,
    required=False,
    default="./output.npz",
    help="result file path.",
)
def get_config(config):
    """Load a GNMT configuration from a JSON file and pin its dtypes.

    Args:
        config (str): Path handed to ``GNMTConfig.from_json_file``.

    Returns:
        The object returned by ``GNMTConfig.from_json_file``, with
        ``compute_type`` set to ``mstype.float16`` and ``dtype`` set to
        ``mstype.float32``.
    """
    gnmt_config = GNMTConfig.from_json_file(config)
    # Both attributes are overwritten regardless of what the JSON contained.
    gnmt_config.compute_type, gnmt_config.dtype = mstype.float16, mstype.float32
    return gnmt_config
def _check_args(config):
    """Validate a config-file argument before it is loaded.

    Args:
        config (str): Path of the configuration file.

    Raises:
        ValueError: If ``config`` is not a ``str``.
        FileNotFoundError: If ``config`` does not exist on disk.
    """
    # The type check has to come first: os.path.exists() raises TypeError for
    # values such as None, which would otherwise hide the intended ValueError.
    if not isinstance(config, str):
        raise ValueError("`config` must be type of str.")
    if not os.path.exists(config):
        raise FileNotFoundError("`config` is not existed.")
if __name__ == '__main__':
    # End-to-end check: distributed training followed by inference, BLEU
    # scoring and a wall-clock budget check.
    start_time = datetime.datetime.now()
    _rank_size = os.getenv('RANK_SIZE')
    args, _ = parser.parse_known_args()

    # train
    _check_args(args.config_train)
    _config_train = get_config(args.config_train)
    _config_train.pre_train_dataset = args.pre_train_dataset
    set_seed(_config_train.random_seed)
    # This test only supports the multi-device path, so RANK_SIZE must be > 1.
    assert _rank_size is not None and int(_rank_size) > 1
    # The guard repeats the assertion so the check still holds under `python -O`,
    # where assert statements are stripped.
    if _rank_size is not None and int(_rank_size) > 1:
        train_parallel(_config_train)

    # eval
    _check_args(args.config_test)
    _config_test = get_config(args.config_test)
    _config_test.test_dataset = args.test_dataset
    _config_test.existed_ckpt = args.existed_ckpt
    result = infer(_config_test)
    # The inference result is pickled (protocol 1) so bleu_calculate can be
    # pointed at the file afterwards.
    with open(args.output, "wb") as f:
        pickle.dump(result, f, 1)

    result_npy_addr = args.output
    vocab = args.vocab
    bpe_codes = args.bpe_codes
    test_tgt = args.test_tgt
    tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
    scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
    print(f"BLEU scores is :{scores}")

    end_time = datetime.datetime.now()
    # timedelta.seconds drops whole days, so a run of 24h or more would wrap
    # around and slip under the time budget; total_seconds() covers the full
    # span. int() keeps the printed value a whole number of seconds.
    cost_time = int((end_time - start_time).total_seconds())
    print(f"Cost time is {cost_time}s.")
    assert scores >= 23.8
    assert cost_time < 10800.0
    print("----done!----")

View File

@ -1,78 +0,0 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Launches one background test_gnmt_v2.py process per device (0..7), each in
# its own freshly created device<i> working directory under the current
# directory.
# NOTE(review): the usage text below says `sh`, but `for((...))` is a bash
# extension; the shebang requests bash - confirm callers invoke it that way.
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh test_gnmt_v2.sh \
GNMT_ADDR RANK_TABLE_ADDR PRE_TRAIN_DATASET TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
echo "for example:"
echo "sh test_gnmt_v2.sh \
/home/workspace/gnmt_v2 \
/home/workspace/rank_table_8p.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord \
/home/workspace/dataset_menu/newstest2014.en.mindrecord \
/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
/home/workspace/wmt16_de_en/vocab.bpe.32000 \
/home/workspace/wmt16_de_en/bpe.32000 \
/home/workspace/wmt16_de_en/newstest2014.de"
echo "It is better to use absolute path."
echo "=============================================================================================================="

# Every positional argument feeds a required option of test_gnmt_v2.py, and an
# empty GNMT_ADDR would turn the copies below into `cp /*.py .`, so stop early.
if [ "$#" -lt 8 ]; then
  echo "Error: expected 8 arguments, got $#." >&2
  exit 1
fi

GNMT_ADDR=$1
RANK_TABLE_ADDR=$2
# train dataset addr
PRE_TRAIN_DATASET=$3
# eval dataset addr
TEST_DATASET=$4
EXISTED_CKPT_PATH=$5
VOCAB_ADDR=$6
BPE_CODE_ADDR=$7
TEST_TARGET=$8

current_exec_path=$(pwd)
echo "${current_exec_path}"

export RANK_TABLE_FILE=$RANK_TABLE_ADDR
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_ADDR
echo "$RANK_TABLE_FILE"
export RANK_SIZE=8
export GLOG_v=2

# Expansions are quoted throughout so paths containing spaces or glob
# characters are passed through as single words.
for((i=0;i<=7;i++));
do
  # Start each device from a clean working directory.
  rm -rf "${current_exec_path}/device$i"
  mkdir "${current_exec_path}/device$i"
  cd "${current_exec_path}/device$i" || exit
  # The globs stay outside the quotes so they still expand.
  cp "${current_exec_path}"/*.py .
  cp "${GNMT_ADDR}"/*.py .
  cp -r "${GNMT_ADDR}/src" .
  cp -r "${GNMT_ADDR}/config" .
  export RANK_ID=$i
  export DEVICE_ID=$i
  # Runs in the background; stdout and stderr go to a per-device log file.
  python test_gnmt_v2.py \
    --config_train="${GNMT_ADDR}/config/config.json" \
    --pre_train_dataset="$PRE_TRAIN_DATASET" \
    --config_test="${GNMT_ADDR}/config/config_test.json" \
    --test_dataset="$TEST_DATASET" \
    --existed_ckpt="$EXISTED_CKPT_PATH" \
    --vocab="$VOCAB_ADDR" \
    --bpe_codes="$BPE_CODE_ADDR" \
    --test_tgt="$TEST_TARGET" > "log_gnmt_network${i}.log" 2>&1 &
  cd "${current_exec_path}" || exit
done
cd "${current_exec_path}" || exit

View File

@ -1,64 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
from tests.st.model_zoo_tests import utils
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single
def test_gnmtv2_WMT_English_German():
    """Run the gnmt_v2 distributed training script and check its logs.

    A copy of the model is made next to this test, ``train.py`` and
    ``config/config.json`` are patched, and
    ``scripts/run_distributed_train_ascend.sh`` is launched. The test then
    asserts that every one of the 8 per-device logs reports a per-step time
    below 330.0 and that the mean of the last parsed loss of each device is
    below 260.
    """
    test_dir = os.path.dirname(os.path.abspath(__file__))
    model_name = "gnmt_v2"
    utils.copy_files("{}/../../../../model_zoo/official/nlp".format(test_dir), test_dir, model_name)
    model_dir = os.path.join(test_dir, model_name)

    # (file relative to the model copy, text to find, replacement text)
    # presumably applied as sed substitutions - verify against utils.exec_sed_command
    patches = (
        ("train.py",
         'dataset_sink_mode=config.dataset_sink_mode',
         'dataset_sink_mode=config.dataset_sink_mode, sink_size=25'),
        ("config/config.json", '"epochs": 6,', '"epochs": 4,'),
    )
    for rel_path, before, after in patches:
        utils.exec_sed_command([before], [after], os.path.join(model_dir, rel_path))

    dataset = os.path.join(
        utils.data_root,
        "wmt16_de_en/train_tok_mindrecord/train.tok.clean.bpe.32000.en.mindrecord")
    launch_cmd = "cd {}/scripts; sh run_distributed_train_ascend.sh {} {}".format(
        model_name, utils.rank_table_path, dataset)
    launch_status = os.system(launch_cmd)
    assert launch_status == 0

    # NOTE(review): process_check presumably polls this ps pipeline until the
    # training processes are gone - confirm against utils.
    ps_cmd = "ps -ef | grep python | grep train.py | grep train.tok.clean.bpe | grep -v grep"
    finished = utils.process_check(120, ps_cmd)
    assert finished

    device_count = 8

    # Performance gate, checked per device.
    perf_log = os.path.join(model_dir, "scripts/device{}/log_gnmt_network{}.log")
    for device in range(device_count):
        per_step_time = utils.get_perf_data(perf_log.format(device, device))
        print("per_step_time is", per_step_time)
        assert per_step_time < 330.0

    # Loss gate, checked on the mean of each device's last parsed loss.
    loss_log = os.path.join(model_dir, "scripts/device{}/loss.log")
    loss_pattern = r"loss\: ([\d\.\+]+)\,"
    loss_list = []
    for device in range(device_count):
        device_losses = utils.parse_log_file(loss_pattern, loss_log.format(device))
        print("loss is", device_losses)
        loss_list.append(device_losses[-1])
    print("loss_list is", loss_list)
    mean_loss = sum(loss_list) / len(loss_list)
    print(mean_loss)
    assert mean_loss < 260

View File

@ -1,48 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
from tests.st.model_zoo_tests import utils
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_ncf():
    """Run the ncf training script and check its log.

    A copy of the model is made next to this test, ``scripts/run_train.sh``
    and ``src/dataset.py`` are patched, and training is started in the
    background with output redirected to ``train.log``. The test then asserts
    that the per-step time read from that log is below 2.0 and that the last
    loss value is below 0.33.
    """
    test_dir = os.path.dirname(os.path.abspath(__file__))
    model_name = "ncf"
    utils.copy_files("{}/../../../../model_zoo/official/recommend".format(test_dir), test_dir, model_name)
    model_dir = os.path.join(test_dir, model_name)

    # (file relative to the model copy, texts to find, replacement texts)
    # presumably applied as sed substitutions - verify against utils.exec_sed_command
    patches = (
        ("scripts/run_train.sh", ["train_epochs 20"], ["train_epochs 4"]),
        ("src/dataset.py",
         ["with open(cache_path, \\\"wb\\\")", "pickle.dump"],
         ["\\# with open(cache_path, \\\"wb\\\")", "\\# pickle.dump"]),
    )
    for rel_path, before, after in patches:
        utils.exec_sed_command(before, after, os.path.join(model_dir, rel_path))

    # The command ends with `&`, so its exit status is not inspected here.
    launch_cmd = "cd ncf; bash scripts/run_train.sh {0} checkpoint/ > train.log 2>&1 &".format(
        os.path.join(utils.data_root, "MovieLens"))
    os.system(launch_cmd)

    # NOTE(review): process_check presumably polls this ps pipeline until the
    # training process is gone - confirm against utils.
    ps_cmd = "ps -ef|grep python|grep train.py|grep train_epochs|grep -v grep"
    finished = utils.process_check(100, ps_cmd)
    assert finished

    train_log = os.path.join(model_dir, "train.log")
    per_step_time = utils.get_perf_data(train_log)
    assert per_step_time < 2.0
    final_loss = utils.get_loss_data_list(train_log)[-1]
    assert final_loss < 0.33