forked from mindspore-Ecosystem/mindspore
commit network lenet and resnet
This commit is contained in:
parent
8de54f3355
commit
fc1546b950
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from tests.st.model_zoo_tests import utils
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lenet_MNIST():
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = "{}/../../../../model_zoo/official/cv".format(cur_path)
|
||||
model_name = "lenet"
|
||||
utils.copy_files(model_path, cur_path, model_name)
|
||||
cur_model_path = os.path.join(cur_path, model_name)
|
||||
train_log = os.path.join(cur_model_path, "train_ascend.log")
|
||||
ckpt_file = os.path.join(cur_model_path, "ckpt/checkpoint_lenet-10_1875.ckpt")
|
||||
infer_log = os.path.join(cur_model_path, "infer_ascend.log")
|
||||
dataset_path = os.path.join(utils.data_root, "mnist")
|
||||
exec_network_shell = "cd {0}; python train.py --data_path={1} > {2} 2>&1"\
|
||||
.format(model_name, dataset_path, train_log)
|
||||
ret = os.system(exec_network_shell)
|
||||
assert ret == 0
|
||||
exec_network_shell = "cd {0}; python eval.py --data_path={1} --ckpt_path={2} > {3} 2>&1"\
|
||||
.format(model_name, dataset_path, ckpt_file, infer_log)
|
||||
ret = os.system(exec_network_shell)
|
||||
assert ret == 0
|
||||
|
||||
per_step_time = utils.get_perf_data(train_log)
|
||||
print("per_step_time is", per_step_time)
|
||||
assert per_step_time < 1.3
|
||||
|
||||
pattern = r"'Accuracy': ([\d\.]+)}"
|
||||
acc = utils.parse_log_file(pattern, infer_log)
|
||||
print("acc is", acc)
|
||||
assert acc[0] > 0.98
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from mindspore import log as logger
|
||||
from tests.st.model_zoo_tests import utils
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_single
|
||||
def test_resnet50_cifar10_ascend():
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = "{}/../../../../model_zoo/official/cv".format(cur_path)
|
||||
model_name = "resnet"
|
||||
utils.copy_files(model_path, cur_path, model_name)
|
||||
cur_model_path = os.path.join(cur_path, "resnet")
|
||||
old_list = ["total_epochs=config.epoch_size", "config.epoch_size - config.pretrain_epoch_size"]
|
||||
new_list = ["total_epochs=10", "10"]
|
||||
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "train.py"))
|
||||
dataset_path = os.path.join(utils.data_root, "cifar-10-batches-bin")
|
||||
exec_network_shell = "cd resnet/scripts; bash run_distribute_train.sh resnet50 cifar10 {} {}"\
|
||||
.format(utils.rank_table_path, dataset_path)
|
||||
os.system(exec_network_shell)
|
||||
cmd = "ps -ef | grep python | grep train.py | grep -v grep"
|
||||
ret = utils.process_check(100, cmd)
|
||||
assert ret
|
||||
log_file = os.path.join(cur_model_path, "scripts/train_parallel{}/log")
|
||||
for i in range(8):
|
||||
per_step_time = utils.get_perf_data(log_file.format(i))
|
||||
assert per_step_time < 20.0
|
||||
loss_list = []
|
||||
for i in range(8):
|
||||
loss = utils.get_loss_data_list(log_file.format(i))
|
||||
loss_list.append(loss[-1])
|
||||
assert sum(loss_list) / len(loss_list) < 0.70
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_single
|
||||
def test_resnet50_cifar10_gpu():
|
||||
cur_path = os.getcwd()
|
||||
model_path = "{}/../../../../model_zoo/official/cv".format(cur_path)
|
||||
model_name = "resnet"
|
||||
utils.copy_files(model_path, cur_path, model_name)
|
||||
cur_model_path = os.path.join(cur_path, "resnet")
|
||||
old_list = ["total_epochs=config.epoch_size", "config.epoch_size - config.pretrain_epoch_size"]
|
||||
new_list = ["total_epochs=10", "10"]
|
||||
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "train.py"))
|
||||
dataset_path = os.path.join(utils.data_root, "cifar-10-batches-bin")
|
||||
exec_network_shell = "cd resnet/scripts; sh run_distribute_train_gpu.sh resnet50 cifar10 {}".format(dataset_path)
|
||||
logger.warning("cmd [{}] is running...".format(exec_network_shell))
|
||||
os.system(exec_network_shell)
|
||||
cmd = "ps -ef | grep python | grep train.py | grep -v grep"
|
||||
ret = utils.process_check(100, cmd)
|
||||
assert ret
|
||||
log_file = os.path.join(cur_model_path, "scripts/train_parallel/log")
|
||||
pattern = r"per step time: ([\d\.]+) ms"
|
||||
step_time_list = utils.parse_log_file(pattern, log_file)[8:]
|
||||
per_step_time = sum(step_time_list) / len(step_time_list)
|
||||
print("step time list is", step_time_list)
|
||||
assert per_step_time < 115
|
||||
loss_list = utils.get_loss_data_list(log_file)[-8:]
|
||||
print("loss_list is", loss_list)
|
||||
assert sum(loss_list) / len(loss_list) < 0.70
|
|
@ -0,0 +1,114 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
""" File Description
|
||||
Details
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import re
|
||||
from mindspore import log as logger
|
||||
|
||||
rank_table_path = "/home/workspace/mindspore_config/hccl/rank_table_8p.json"
|
||||
data_root = "/home/workspace/mindspore_dataset/"
|
||||
ckpt_root = "/home/workspace/mindspore_ckpt/"
|
||||
cur_path = os.path.split(os.path.realpath(__file__))[0]
|
||||
geir_root = os.path.join(cur_path, "mindspore_geir")
|
||||
arm_main_path = os.path.join(cur_path, "mindir_310infer_exe")
|
||||
model_zoo_path = os.path.join(cur_path, "../../../model_zoo")
|
||||
|
||||
|
||||
def copy_files(from_, to_, model_name):
|
||||
if not os.path.exists(os.path.join(from_, model_name)):
|
||||
raise ValueError("There is no file or path", os.path.join(from_, model_name))
|
||||
if os.path.exists(os.path.join(to_, model_name)):
|
||||
shutil.rmtree(os.path.join(to_, model_name))
|
||||
return os.system("cp -r {0} {1}".format(os.path.join(from_, model_name), to_))
|
||||
|
||||
|
||||
def exec_sed_command(old_list, new_list, file):
|
||||
if isinstance(old_list, str):
|
||||
old_list = [old_list]
|
||||
if isinstance(new_list, str):
|
||||
old_list = [new_list]
|
||||
if len(old_list) != len(new_list):
|
||||
raise ValueError("len(old_list) should be equal to len(new_list)")
|
||||
for old, new in zip(old_list, new_list):
|
||||
ret = os.system('sed -i "s#{0}#{1}#g" {2}'.format(old, new, file))
|
||||
if ret != 0:
|
||||
raise ValueError('exec `sed -i "s#{0}#{1}#g" {2}` failed.'.format(old, new, file))
|
||||
return ret
|
||||
|
||||
|
||||
def process_check(cycle_time, cmd, wait_time=5):
|
||||
for i in range(cycle_time):
|
||||
time.sleep(wait_time)
|
||||
sub = subprocess.Popen(args="{}".format(cmd), shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, universal_newlines=True)
|
||||
stdout_data, _ = sub.communicate()
|
||||
if not stdout_data:
|
||||
logger.info("process execute success.")
|
||||
return True
|
||||
logger.warning("process is running, please wait {}".format(i))
|
||||
logger.error("process execute execute timeout.")
|
||||
return False
|
||||
|
||||
|
||||
def get_perf_data(log_path, search_str="per step time", cmd=None):
|
||||
if cmd is None:
|
||||
get_step_times_cmd = r"""grep -a "{0}" {1}|egrep -v "loss|\]|\["|awk '{{print $(NF-1)}}'""" \
|
||||
.format(search_str, log_path)
|
||||
else:
|
||||
get_step_times_cmd = cmd
|
||||
sub = subprocess.Popen(args="{}".format(get_step_times_cmd), shell=True,
|
||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, universal_newlines=True)
|
||||
stdout, _ = sub.communicate()
|
||||
if sub.returncode != 0:
|
||||
raise RuntimeError("exec {} failed".format(cmd))
|
||||
logger.info("execute {} success".format(cmd))
|
||||
stdout = stdout.strip().split("\n")
|
||||
step_time_list = list(map(float, stdout[1:]))
|
||||
if not step_time_list:
|
||||
cmd = "cat {}".format(log_path)
|
||||
os.system(cmd)
|
||||
raise RuntimeError("step_time_list is empty")
|
||||
per_step_time = sum(step_time_list) / len(step_time_list)
|
||||
return per_step_time
|
||||
|
||||
|
||||
def get_loss_data_list(log_path, search_str="loss is", cmd=None):
|
||||
if cmd is None:
|
||||
loss_value_cmd = """ grep -a '{}' {}| awk '{{print $NF}}' """.format(search_str, log_path)
|
||||
else:
|
||||
loss_value_cmd = cmd
|
||||
sub = subprocess.Popen(args="{}".format(loss_value_cmd), shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, universal_newlines=True)
|
||||
stdout, _ = sub.communicate()
|
||||
if sub.returncode != 0:
|
||||
raise RuntimeError("get loss from {} failed".format(log_path))
|
||||
logger.info("execute {} success".format(cmd))
|
||||
stdout = stdout.strip().split("\n")
|
||||
loss_list = list(map(float, stdout))
|
||||
if not loss_list:
|
||||
cmd = "cat {}".format(log_path)
|
||||
os.system(cmd)
|
||||
raise RuntimeError("loss_list is empty")
|
||||
return loss_list
|
||||
|
||||
|
||||
def parse_log_file(pattern, log_path):
|
||||
value_list = []
|
||||
with open(log_path, "r") as file:
|
||||
for line in file.readlines():
|
||||
match_result = re.search(pattern, line)
|
||||
if match_result is not None:
|
||||
value_list.append(float(match_result.group(1)))
|
||||
if not value_list:
|
||||
print("pattern is", pattern)
|
||||
cmd = "cat {}".format(log_path)
|
||||
os.system(cmd)
|
||||
return value_list
|
Loading…
Reference in New Issue