update st test case: add error log after fail

Signed-off-by: zhushujing <zhushujing@huawei.com>
This commit is contained in:
zhushujing 2021-03-29 21:14:08 +08:00
parent b802f25563
commit fc0499055d
2 changed files with 4 additions and 1 deletions

View File

@ -24,4 +24,5 @@ import pytest
def test_broadcast_auto_parallel(): def test_broadcast_auto_parallel():
sh_path = os.path.split(os.path.realpath(__file__))[0] sh_path = os.path.split(os.path.realpath(__file__))[0]
ret = os.system(f"sh {sh_path}/run_broadcast_auto_parallel.sh") ret = os.system(f"sh {sh_path}/run_broadcast_auto_parallel.sh")
os.system(f"grep -E 'ERROR|error' {sh_path}/lenet_broadcast*/test_lenet_auto_parallel_broadcast_8p_log*log -C 3")
assert ret == 0 assert ret == 0

View File

@ -36,7 +36,6 @@ def read_validateir_file(path_folder):
filename = find_newest_validateir_file(path_folder) filename = find_newest_validateir_file(path_folder)
with open(os.path.join(filename), 'r') as f: with open(os.path.join(filename), 'r') as f:
contend = f.read() contend = f.read()
clean_all_ir_files(path_folder)
return contend return contend
@ -130,10 +129,12 @@ def test_sit_auto_mix_precision_model_o0():
contend = read_validateir_file('./test_amp_o0') contend = read_validateir_file('./test_amp_o0')
castnum = re.findall("Cast", contend) castnum = re.findall("Cast", contend)
assert len(castnum) == 5 assert len(castnum) == 5
clean_all_ir_files('./test_amp_o0')
model.predict(Tensor(input_data)) model.predict(Tensor(input_data))
contend = read_validateir_file('./test_amp_o0') contend = read_validateir_file('./test_amp_o0')
castnum = re.findall("Cast", contend) castnum = re.findall("Cast", contend)
assert len(castnum) == 11 assert len(castnum) == 11
clean_all_ir_files('./test_amp_o0')
@pytest.mark.level0 @pytest.mark.level0
@ -164,6 +165,7 @@ def test_sit_auto_mix_precision_model_o2():
contend = read_validateir_file('./test_amp_o2') contend = read_validateir_file('./test_amp_o2')
castnum = re.findall("Cast", contend) castnum = re.findall("Cast", contend)
assert len(castnum) == 14 assert len(castnum) == 14
clean_all_ir_files('./test_amp_o2')
out_graph = model.predict(Tensor(input_data)) out_graph = model.predict(Tensor(input_data))
# pynative mode # pynative mode