use temporary dir as dump dir

This commit is contained in:
yelihua 2021-08-17 20:53:07 +08:00
parent c35d32d45a
commit b4c82be639
1 changed files with 52 additions and 56 deletions

View File

@ -15,6 +15,7 @@
import os
import json
import sys
import tempfile
import time
import shutil
import glob
@ -46,12 +47,11 @@ x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
y = np.array([[7, 8, 9], [10, 11, 12]]).astype(np.float32)
def change_current_dump_json(file_name, dump_path):
def change_current_dump_json(file_name, dump_path, dump_config_path):
with open(file_name, 'r+') as f:
data = json.load(f)
data["common_dump_settings"]["path"] = dump_path
with open(file_name, 'w') as f:
with open(dump_config_path, 'w') as f:
json.dump(data, f)
@ -62,10 +62,12 @@ def change_current_dump_json(file_name, dump_path):
def test_async_dump():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
pwd = os.getcwd()
dump_path = pwd + "/async_dump"
change_current_dump_json('async_dump.json', dump_path)
os.environ['MINDSPORE_DUMP_CONFIG'] = pwd + "/async_dump.json"
dump_file_path = dump_path + '/rank_0/Net/0/0/'
with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir:
dump_path = os.path.join(tmp_dir, 'async_dump')
dump_config_path = os.path.join(tmp_dir, 'async_dump.json')
change_current_dump_json('async_dump.json', dump_path, dump_config_path)
os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
if os.path.isdir(dump_path):
shutil.rmtree(dump_path)
add = Net()
@ -73,23 +75,21 @@ def test_async_dump():
time.sleep(5)
assert len(os.listdir(dump_file_path)) == 1
# Delete generated dump data
os.system("rm -rf {}".format(dump_path))
def run_e2e_dump():
if sys.platform != 'linux':
return
pwd = os.getcwd()
dump_path = pwd + '/e2e_dump'
change_current_dump_json('e2e_dump.json', dump_path)
os.environ['MINDSPORE_DUMP_CONFIG'] = pwd + '/e2e_dump.json'
dump_file_path = dump_path + '/rank_0/Net/0/0/'
with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir:
dump_path = os.path.join(tmp_dir, 'e2e_dump')
dump_config_path = os.path.join(tmp_dir, 'e2e_dump.json')
change_current_dump_json('e2e_dump.json', dump_path, dump_config_path)
os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
if os.path.isdir(dump_path):
shutil.rmtree(dump_path)
add = Net()
add(Tensor(x), Tensor(y))
time.sleep(5)
if context.get_context("device_target") == "Ascend":
assert len(os.listdir(dump_file_path)) == 5
output_name = "Add.Add-op1.0.0.*.output.0.DefaultFormat.npy"
@ -99,16 +99,13 @@ def run_e2e_dump():
else:
assert len(os.listdir(dump_file_path)) == 3
output_name = "Add.Add-op3.0.0.*.output.0.DefaultFormat.npy"
output_path = glob.glob(dump_file_path + output_name)[0]
output_path = glob.glob(os.path.join(dump_file_path, output_name))[0]
real_path = os.path.realpath(output_path)
output = np.load(real_path)
expect = np.array([[8, 10, 12], [14, 16, 18]], np.float32)
assert output.dtype == expect.dtype
assert np.array_equal(output, expect)
# Delete generated dump data
os.system("rm -rf {}".format(dump_path))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@ -257,16 +254,15 @@ def test_dump_with_diagnostic_path():
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
pwd = os.getcwd()
change_current_dump_json('e2e_dump.json', '')
os.environ['MINDSPORE_DUMP_CONFIG'] = pwd + "/e2e_dump.json"
diagnose_path = pwd + "/e2e_dump"
with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir:
dump_config_path = os.path.join(tmp_dir, 'e2e_dump.json')
change_current_dump_json('e2e_dump.json', '', dump_config_path)
os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
diagnose_path = os.path.join(tmp_dir, 'e2e_dump')
os.environ['MS_DIAGNOSTIC_DATA_PATH'] = diagnose_path
dump_file_path = diagnose_path + '/debug_dump/rank_0/Net/0/0/'
dump_file_path = os.path.join(diagnose_path, 'debug_dump', 'rank_0', 'Net', '0', '0')
if os.path.isdir(diagnose_path):
shutil.rmtree(diagnose_path)
add = Net()
add(Tensor(x), Tensor(y))
assert len(os.listdir(dump_file_path)) == 5
# Delete generated dump data
os.system("rm -rf {}".format(diagnose_path))