diff --git a/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc b/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc index 5e30142e3e0..c62716b4a9d 100644 --- a/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc +++ b/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc @@ -232,7 +232,6 @@ void DumpJsonParser::ParseCommonDumpSetting(const nlohmann::json &content) { auto common_dump_settings = CheckJsonKeyExist(content, kCommonDumpSettings); auto dump_mode = CheckJsonKeyExist(*common_dump_settings, kDumpMode); - auto path = CheckJsonKeyExist(*common_dump_settings, kPath); auto net_name = CheckJsonKeyExist(*common_dump_settings, kNetName); auto iteration = CheckJsonKeyExist(*common_dump_settings, kIteration); auto input_output = CheckJsonKeyExist(*common_dump_settings, kInputOutput); @@ -245,7 +244,7 @@ void DumpJsonParser::ParseCommonDumpSetting(const nlohmann::json &content) { } ParseDumpMode(*dump_mode); - ParseDumpPath(*path); + ParseDumpPath(*common_dump_settings); // Pass in the whole json string to parse because the path field is optional. ParseNetName(*net_name); ParseIteration(*iteration); ParseInputOutput(*input_output); @@ -302,15 +301,29 @@ void DumpJsonParser::ParseDumpMode(const nlohmann::json &content) { } void DumpJsonParser::ParseDumpPath(const nlohmann::json &content) { - CheckJsonStringType(content, kPath); - path_ = content; + std::string dump_path; + auto json_iter = content.find(kPath); + // Check if `path` field exists in dump json file. + if (json_iter != content.end()) { + CheckJsonStringType(*json_iter, kPath); + dump_path = *json_iter; + } + if (dump_path.empty()) { + // If no path is found or path is set as empty in dump json file, use MS_DIAGNOSTIC_DATA_PATH/debug_dump as the dump + // path value if the env exists. + dump_path = common::GetEnv("MS_DIAGNOSTIC_DATA_PATH"); + if (dump_path.empty()) { + MS_LOG(EXCEPTION) + << "Dump path is empty. Please set it in dump json file or environment variable `MS_DIAGNOSTIC_DATA_PATH`."; + } else { + dump_path += "/debug_dump"; + } + } + path_ = dump_path; if (!std::all_of(path_.begin(), path_.end(), [](char c) { return ::isalpha(c) || ::isdigit(c) || c == '-' || c == '_' || c == '/'; })) { MS_LOG(EXCEPTION) << "Dump path only support alphabets, digit or {'-', '_', '/'}, but got:" << path_; } - if (path_.empty()) { - MS_LOG(EXCEPTION) << "Dump path is empty"; - } if (path_[0] != '/') { MS_LOG(EXCEPTION) << "Dump path only support absolute path and should start with '/'"; } diff --git a/tests/st/dump/test_data_dump.py b/tests/st/dump/test_data_dump.py index 43234ef226a..29056acf1ae 100644 --- a/tests/st/dump/test_data_dump.py +++ b/tests/st/dump/test_data_dump.py @@ -73,6 +73,9 @@ def test_async_dump(): time.sleep(5) assert len(os.listdir(dump_file_path)) == 1 + # Delete generated dump data + os.system("rm -rf {}".format(dump_path)) + def run_e2e_dump(): if sys.platform != 'linux': @@ -103,6 +106,9 @@ def run_e2e_dump(): assert output.dtype == expect.dtype assert np.array_equal(output, expect) + # Delete generated dump data + os.system("rm -rf {}".format(dump_path)) + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @@ -205,3 +211,29 @@ def test_async_dump_net_multi_layer_mode1(): assert value.asnumpy() == dump_result["output0"][index] else: print('not find convert tools msaccucmp.pyc') + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_dump_with_diagnostic_path(): + """ + Test e2e dump when path is not set (set to empty) in dump json file and MS_DIAGNOSTIC_DATA_PATH is set. + Data is expected to be dumped into MS_DIAGNOSTIC_DATA_PATH/debug_dump. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + pwd = os.getcwd() + change_current_dump_json('e2e_dump.json', '') + os.environ['MINDSPORE_DUMP_CONFIG'] = pwd + "/e2e_dump.json" + diagnose_path = pwd + "/e2e_dump" + os.environ['MS_DIAGNOSTIC_DATA_PATH'] = diagnose_path + dump_file_path = diagnose_path + '/debug_dump/rank_0/Net/0/0/' + if os.path.isdir(diagnose_path): + shutil.rmtree(diagnose_path) + add = Net() + add(Tensor(x), Tensor(y)) + assert len(os.listdir(dump_file_path)) == 5 + + # Delete generated dump data + os.system("rm -rf {}".format(diagnose_path))