retrieve modelzoo network gnmt_v2 and ncf testcase
This commit is contained in:
parent
475386e338
commit
6108bb444b
|
@ -18,8 +18,3 @@ packaging >= 20.0
|
|||
pycocotools >= 2.0.2 # for st test
|
||||
tables >= 3.6.1 # for st test
|
||||
psutil >= 5.7.0
|
||||
subword-nmt>=0.3.7 # for st test
|
||||
sacrebleu>=1.4.14 # for st test
|
||||
sacremoses>=0.0.35 # for st test
|
||||
absl-py>=0.10.0 # for st test
|
||||
six>=1.15.0 # for st test
|
|
@ -1,103 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train and eval api."""
|
||||
import os
|
||||
import argparse
|
||||
import pickle
|
||||
import datetime
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from config import GNMTConfig
|
||||
from train import train_parallel
|
||||
from src.gnmt_model import infer
|
||||
from src.gnmt_model.bleu_calculate import bleu_calculate
|
||||
from src.dataset.tokenizer import Tokenizer
|
||||
|
||||
parser = argparse.ArgumentParser(description='GNMT train and eval.')
|
||||
# train
|
||||
parser.add_argument("--config_train", type=str, required=True,
|
||||
help="model config json file path.")
|
||||
parser.add_argument("--pre_train_dataset", type=str, required=True,
|
||||
help="pre-train dataset address.")
|
||||
# eval
|
||||
parser.add_argument("--config_test", type=str, required=True,
|
||||
help="model config json file path.")
|
||||
parser.add_argument("--test_dataset", type=str, required=True,
|
||||
help="test dataset address.")
|
||||
parser.add_argument("--existed_ckpt", type=str, required=True,
|
||||
help="existed checkpoint address.")
|
||||
parser.add_argument("--vocab", type=str, required=True,
|
||||
help="Vocabulary to use.")
|
||||
parser.add_argument("--bpe_codes", type=str, required=True,
|
||||
help="bpe codes to use.")
|
||||
parser.add_argument("--test_tgt", type=str, required=True,
|
||||
default=None,
|
||||
help="data file of the test target")
|
||||
parser.add_argument("--output", type=str, required=False,
|
||||
default="./output.npz",
|
||||
help="result file path.")
|
||||
|
||||
|
||||
def get_config(config):
|
||||
config = GNMTConfig.from_json_file(config)
|
||||
config.compute_type = mstype.float16
|
||||
config.dtype = mstype.float32
|
||||
return config
|
||||
|
||||
|
||||
def _check_args(config):
|
||||
if not os.path.exists(config):
|
||||
raise FileNotFoundError("`config` is not existed.")
|
||||
if not isinstance(config, str):
|
||||
raise ValueError("`config` must be type of str.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_time = datetime.datetime.now()
|
||||
_rank_size = os.getenv('RANK_SIZE')
|
||||
args, _ = parser.parse_known_args()
|
||||
# train
|
||||
_check_args(args.config_train)
|
||||
_config_train = get_config(args.config_train)
|
||||
_config_train.pre_train_dataset = args.pre_train_dataset
|
||||
set_seed(_config_train.random_seed)
|
||||
assert _rank_size is not None and int(_rank_size) > 1
|
||||
if _rank_size is not None and int(_rank_size) > 1:
|
||||
train_parallel(_config_train)
|
||||
# eval
|
||||
_check_args(args.config_test)
|
||||
_config_test = get_config(args.config_test)
|
||||
_config_test.test_dataset = args.test_dataset
|
||||
_config_test.existed_ckpt = args.existed_ckpt
|
||||
result = infer(_config_test)
|
||||
|
||||
with open(args.output, "wb") as f:
|
||||
pickle.dump(result, f, 1)
|
||||
|
||||
result_npy_addr = args.output
|
||||
vocab = args.vocab
|
||||
bpe_codes = args.bpe_codes
|
||||
test_tgt = args.test_tgt
|
||||
tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
|
||||
scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
|
||||
print(f"BLEU scores is :{scores}")
|
||||
end_time = datetime.datetime.now()
|
||||
cost_time = (end_time - start_time).seconds
|
||||
print(f"Cost time is {cost_time}s.")
|
||||
assert scores >= 23.8
|
||||
assert cost_time < 10800.0
|
||||
print("----done!----")
|
|
@ -1,78 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh test_gnmt_v2.sh \
|
||||
GNMT_ADDR RANK_TABLE_ADDR PRE_TRAIN_DATASET TEST_DATASET EXISTED_CKPT_PATH \
|
||||
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
|
||||
echo "for example:"
|
||||
echo "sh test_gnmt_v2.sh \
|
||||
/home/workspace/gnmt_v2 \
|
||||
/home/workspace/rank_table_8p.json \
|
||||
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord \
|
||||
/home/workspace/dataset_menu/newstest2014.en.mindrecord \
|
||||
/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
|
||||
/home/workspace/wmt16_de_en/vocab.bpe.32000 \
|
||||
/home/workspace/wmt16_de_en/bpe.32000 \
|
||||
/home/workspace/wmt16_de_en/newstest2014.de"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
GNMT_ADDR=$1
|
||||
RANK_TABLE_ADDR=$2
|
||||
# train dataset addr
|
||||
PRE_TRAIN_DATASET=$3
|
||||
# eval dataset addr
|
||||
TEST_DATASET=$4
|
||||
EXISTED_CKPT_PATH=$5
|
||||
VOCAB_ADDR=$6
|
||||
BPE_CODE_ADDR=$7
|
||||
TEST_TARGET=$8
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
export RANK_TABLE_FILE=$RANK_TABLE_ADDR
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_ADDR
|
||||
|
||||
echo $RANK_TABLE_FILE
|
||||
export RANK_SIZE=8
|
||||
export GLOG_v=2
|
||||
|
||||
for((i=0;i<=7;i++));
|
||||
do
|
||||
rm -rf ${current_exec_path}/device$i
|
||||
mkdir ${current_exec_path}/device$i
|
||||
cd ${current_exec_path}/device$i || exit
|
||||
cp ${current_exec_path}/*.py .
|
||||
cp ${GNMT_ADDR}/*.py .
|
||||
cp -r ${GNMT_ADDR}/src .
|
||||
cp -r ${GNMT_ADDR}/config .
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python test_gnmt_v2.py \
|
||||
--config_train=${GNMT_ADDR}/config/config.json \
|
||||
--pre_train_dataset=$PRE_TRAIN_DATASET \
|
||||
--config_test=${GNMT_ADDR}/config/config_test.json \
|
||||
--test_dataset=$TEST_DATASET \
|
||||
--existed_ckpt=$EXISTED_CKPT_PATH \
|
||||
--vocab=$VOCAB_ADDR \
|
||||
--bpe_codes=$BPE_CODE_ADDR \
|
||||
--test_tgt=$TEST_TARGET > log_gnmt_network${i}.log 2>&1 &
|
||||
cd ${current_exec_path} || exit
|
||||
done
|
||||
cd ${current_exec_path} || exit
|
|
@ -1,64 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from tests.st.model_zoo_tests import utils
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_single
|
||||
def test_gnmtv2_WMT_English_German():
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = "{}/../../../../model_zoo/official/nlp".format(cur_path)
|
||||
model_name = "gnmt_v2"
|
||||
utils.copy_files(model_path, cur_path, model_name)
|
||||
cur_model_path = os.path.join(cur_path, model_name)
|
||||
|
||||
old_list = ['dataset_sink_mode=config.dataset_sink_mode']
|
||||
new_list = ['dataset_sink_mode=config.dataset_sink_mode, sink_size=25']
|
||||
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "train.py"))
|
||||
old_list = ['"epochs": 6,']
|
||||
new_list = ['"epochs": 4,']
|
||||
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "config/config.json"))
|
||||
|
||||
mindrecord_file = "wmt16_de_en/train_tok_mindrecord/train.tok.clean.bpe.32000.en.mindrecord"
|
||||
exec_network_shell = "cd {}/scripts; sh run_distributed_train_ascend.sh {} {}"\
|
||||
.format(model_name, utils.rank_table_path, os.path.join(utils.data_root, mindrecord_file))
|
||||
ret = os.system(exec_network_shell)
|
||||
assert ret == 0
|
||||
|
||||
cmd = "ps -ef | grep python | grep train.py | grep train.tok.clean.bpe | grep -v grep"
|
||||
ret = utils.process_check(120, cmd)
|
||||
assert ret
|
||||
|
||||
log_file = os.path.join(cur_model_path, "scripts/device{}/log_gnmt_network{}.log")
|
||||
for i in range(8):
|
||||
per_step_time = utils.get_perf_data(log_file.format(i, i))
|
||||
print("per_step_time is", per_step_time)
|
||||
assert per_step_time < 330.0
|
||||
|
||||
log_file = os.path.join(cur_model_path, "scripts/device{}/loss.log")
|
||||
loss_list = []
|
||||
for i in range(8):
|
||||
pattern1 = r"loss\: ([\d\.\+]+)\,"
|
||||
loss = utils.parse_log_file(pattern1, log_file.format(i))
|
||||
print("loss is", loss)
|
||||
loss_list.append(loss[-1])
|
||||
print("loss_list is", loss_list)
|
||||
print(sum(loss_list) / len(loss_list))
|
||||
assert sum(loss_list) / len(loss_list) < 260
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from tests.st.model_zoo_tests import utils
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ncf():
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = "{}/../../../../model_zoo/official/recommend".format(cur_path)
|
||||
model_name = "ncf"
|
||||
utils.copy_files(model_path, cur_path, model_name)
|
||||
cur_model_path = os.path.join(cur_path, model_name)
|
||||
old_list = ["train_epochs 20"]
|
||||
new_list = ["train_epochs 4"]
|
||||
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "scripts/run_train.sh"))
|
||||
old_list = ["with open(cache_path, \\\"wb\\\")", "pickle.dump"]
|
||||
new_list = ["\\# with open(cache_path, \\\"wb\\\")", "\\# pickle.dump"]
|
||||
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "src/dataset.py"))
|
||||
dataset_path = os.path.join(utils.data_root, "MovieLens")
|
||||
exec_network_shell = "cd ncf; bash scripts/run_train.sh {0} checkpoint/ > train.log 2>&1 &"\
|
||||
.format(dataset_path)
|
||||
os.system(exec_network_shell)
|
||||
cmd = "ps -ef|grep python|grep train.py|grep train_epochs|grep -v grep"
|
||||
ret = utils.process_check(100, cmd)
|
||||
assert ret
|
||||
log_file = os.path.join(cur_model_path, "train.log")
|
||||
per_step_time = utils.get_perf_data(log_file)
|
||||
assert per_step_time < 2.0
|
||||
loss = utils.get_loss_data_list(log_file)[-1]
|
||||
assert loss < 0.33
|
Loading…
Reference in New Issue