From e0289cd196f10cb3dddf111ace416067e93f0494 Mon Sep 17 00:00:00 2001 From: anzhengqi Date: Tue, 27 Apr 2021 15:28:03 +0800 Subject: [PATCH] add modelzoo network bgcf testcase --- .../bgcf/test_BGCF_amazon_beauty.py | 72 +++++++++++++++++++ tests/st/model_zoo_tests/utils.py | 2 +- 2 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 tests/st/model_zoo_tests/bgcf/test_BGCF_amazon_beauty.py diff --git a/tests/st/model_zoo_tests/bgcf/test_BGCF_amazon_beauty.py b/tests/st/model_zoo_tests/bgcf/test_BGCF_amazon_beauty.py new file mode 100644 index 00000000000..6dcea24fb8d --- /dev/null +++ b/tests/st/model_zoo_tests/bgcf/test_BGCF_amazon_beauty.py @@ -0,0 +1,72 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import pytest + +from tests.st.model_zoo_tests import utils + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_BGCF_amazon_beauty(): + cur_path = os.path.dirname(os.path.abspath(__file__)) + model_path = "{}/../../../../model_zoo/official/gnn".format(cur_path) + model_name = "bgcf" + utils.copy_files(model_path, cur_path, model_name) + cur_model_path = os.path.join(cur_path, model_name) + + old_list = ["--datapath=../data_mr"] + new_list = ["--datapath={}".format(os.path.join(utils.data_root, "amazon_beauty/mindrecord_train"))] + utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "scripts/run_train_ascend.sh")) + old_list = ["default=600,"] + new_list = ["default=50,"] + utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "src/config.py")) + old_list = ["context.set_context(device_id=int(parser.device))"] + new_list = ["context.set_context()"] + utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "train.py")) + exec_network_shell = "cd {}/scripts; bash run_train_ascend.sh".format(model_name) + ret = os.system(exec_network_shell) + assert ret == 0 + + cmd = "ps -ef|grep python |grep train.py|grep amazon_beauty|grep -v grep" + ret = utils.process_check(300, cmd) + assert ret + + log_file = os.path.join(cur_model_path, "scripts/train/log") + pattern1 = r"loss ([\d\.\+]+)\," + loss_list = utils.parse_log_file(pattern1, log_file) + loss_list = loss_list[-5:] + print("last 5 epoch average loss is", sum(loss_list) / len(loss_list)) + assert sum(loss_list) / len(loss_list) < 6400 + + pattern1 = r"cost:([\d\.\+]+)" + epoch_time_list = utils.parse_log_file(pattern1, log_file)[1:] + print("per epoch time:", sum(epoch_time_list) / len(epoch_time_list)) + assert sum(epoch_time_list) / len(epoch_time_list) < 1.9 + + +def test_bgcf_export_mindir(): + cur_path = os.getcwd() + model_path = "{}/../../../../model_zoo/official/gnn".format(cur_path) + model_name = "bgcf" + utils.copy_files(model_path, cur_path, model_name) + cur_model_path = os.path.join(cur_path, model_name) + + ckpt_path = os.path.join(utils.ckpt_root, "bgcf/bgcf_trained.ckpt") + exec_export_shell = "cd {}; python export.py --ckpt_file={} --file_format=MINDIR".format(model_name, ckpt_path) + os.system(exec_export_shell) + assert os.path.exists(os.path.join(cur_model_path, "{}.mindir".format(model_name))) diff --git a/tests/st/model_zoo_tests/utils.py b/tests/st/model_zoo_tests/utils.py index a5b31520310..fe773449d96 100644 --- a/tests/st/model_zoo_tests/utils.py +++ b/tests/st/model_zoo_tests/utils.py @@ -14,7 +14,7 @@ from mindspore import log as logger rank_table_path = "/home/workspace/mindspore_config/hccl/rank_table_8p.json" data_root = "/home/workspace/mindspore_dataset/" -ckpt_root = "/home/workspace/mindspore_ckpt/" +ckpt_root = "/home/workspace/mindspore_dataset/checkpoint" cur_path = os.path.split(os.path.realpath(__file__))[0] geir_root = os.path.join(cur_path, "mindspore_geir") arm_main_path = os.path.join(cur_path, "mindir_310infer_exe")