From 6de897372ab0be22d4d1bf9622b7476f29ad5222 Mon Sep 17 00:00:00 2001 From: anzhengqi Date: Tue, 27 Apr 2021 11:33:09 +0800 Subject: [PATCH] add modelzoo network gnmtv2 testcase --- requirements.txt | 3 + .../gnmt_v2/test_gnmtv2_WMT_English_German.py | 64 +++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 tests/st/model_zoo_tests/gnmt_v2/test_gnmtv2_WMT_English_German.py diff --git a/requirements.txt b/requirements.txt index 01133772fea..f1b8979284d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,6 @@ packaging >= 20.0 pycocotools >= 2.0.0 # for st test tables >= 3.6.1 # for st test psutil >= 5.6.1 +subword-nmt>=0.3.7 # for st test +sacrebleu>=1.4.14 # for st test +sacremoses>=0.0.35 # for st test diff --git a/tests/st/model_zoo_tests/gnmt_v2/test_gnmtv2_WMT_English_German.py b/tests/st/model_zoo_tests/gnmt_v2/test_gnmtv2_WMT_English_German.py new file mode 100644 index 00000000000..56e83a2484b --- /dev/null +++ b/tests/st/model_zoo_tests/gnmt_v2/test_gnmtv2_WMT_English_German.py @@ -0,0 +1,64 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import pytest + +from tests.st.model_zoo_tests import utils + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_single +def test_gnmtv2_WMT_English_German(): + cur_path = os.path.dirname(os.path.abspath(__file__)) + model_path = "{}/../../../../model_zoo/official/nlp".format(cur_path) + model_name = "gnmt_v2" + utils.copy_files(model_path, cur_path, model_name) + cur_model_path = os.path.join(cur_path, model_name) + + old_list = ['dataset_sink_mode=config.dataset_sink_mode'] + new_list = ['dataset_sink_mode=config.dataset_sink_mode, sink_size=100'] + utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "train.py")) + old_list = ['"epochs": 6,'] + new_list = ['"epochs": 4,'] + utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "config/config.json")) + + mindrecord_file = "wmt16_de_en/train_tok_mindrecord/train.tok.clean.bpe.32000.en.mindrecord" + exec_network_shell = "cd {}/scripts; sh run_distributed_train_ascend.sh {} {}"\ + .format(model_name, utils.rank_table_path, os.path.join(utils.data_root, mindrecord_file)) + ret = os.system(exec_network_shell) + assert ret == 0 + + cmd = "ps -ef | grep python | grep train.py | grep train.tok.clean.bpe | grep -v grep" + ret = utils.process_check(120, cmd) + assert ret + + log_file = os.path.join(cur_model_path, "scripts/device{}/log_gnmt_network{}.log") + for i in range(8): + per_step_time = utils.get_perf_data(log_file.format(i, i)) + print("per_step_time is", per_step_time) + assert per_step_time < 270.0 + + log_file = os.path.join(cur_model_path, "scripts/device{}/loss.log") + loss_list = [] + for i in range(8): + pattern1 = r"loss\: ([\d\.\+]+)\," + loss = utils.parse_log_file(pattern1, log_file.format(i)) + print("loss is", loss) + loss_list.append(loss[-1]) + print("loss_list is", loss_list) + print(sum(loss_list) / len(loss_list)) + assert sum(loss_list) / len(loss_list) < 120