diff --git a/mindspore/ccsrc/debug/debug_services.cc b/mindspore/ccsrc/debug/debug_services.cc index 3f9d851057f..948fd7da189 100644 --- a/mindspore/ccsrc/debug/debug_services.cc +++ b/mindspore/ccsrc/debug/debug_services.cc @@ -26,6 +26,7 @@ #include #include #include "pybind11/embed.h" +#include "pybind11/stl.h" #ifdef ONLINE_DBG_MODE #include "debug/common.h" #include "debug/debugger/debugger.h" @@ -549,6 +550,7 @@ void DebugServices::ConvertToHostFormat(const std::map files_to_convert_in_dir; + std::vector files_after_convert_in_dir; std::string dump_key = d.first; for (auto const &file_name : d.second) { bool already_converted = false; @@ -567,26 +569,19 @@ void DebugServices::ConvertToHostFormat(const std::map(input_file_o, delim)); - std::string input_files = input_file_o.str(); - MS_LOG(INFO) << "Ops to convert: " << input_files; - if (input_files != "") { + MS_LOG(INFO) << "Number of files to convert: " << files_to_convert_in_dir.size(); + if (!files_to_convert_in_dir.empty()) { // Look for the installation path to the conver_async package. If not found, throw exception and terminate the // later task. try { auto pkg = pybind11::module::import("mindspore.offline_debug.convert_async"); - std::string convert_pkg_path = pkg.attr("__file__").cast(); - MS_LOG(INFO) << "The file for converting async dump data is in " << convert_pkg_path; - std::string convert_command = "python " + convert_pkg_path + " -out " + dump_key + " -t " + file_format + - " -d " + dump_key + " -f NCHW -l " + input_files; - (void)(system(convert_command.c_str()) + 1); + auto convert_obj = pkg.attr("AsyncDumpConverter")(pybind11::cast(files_to_convert_in_dir), dump_key); + (void)convert_obj.attr("convert_files")(); } catch (pybind11::error_already_set &e) { - MS_LOG(EXCEPTION) << "Can't find package mindspore.offline_debug.convert_async"; + MS_LOG(EXCEPTION) << "Failed to convert async dump data: " << e.what(); } std::string abspath = RealPath(dump_key); @@ -599,7 +594,7 @@ void DebugServices::ConvertToHostFormat(const std::mapd_type == DT_REG) { std::string candidate = dir->d_name; - for (const std::string &file_to_find : files_to_convert_in_dir) { + for (const std::string &file_to_find : files_after_convert_in_dir) { std::string file_n = file_to_find.substr(file_to_find.find_last_of("\\/") + 1); if (candidate.find(file_n) != std::string::npos && candidate.rfind(file_format) != std::string::npos) { // we found a converted file for this op diff --git a/mindspore/offline_debug/convert_async.py b/mindspore/offline_debug/convert_async.py index d77f78dff4c..0d137034c58 100644 --- a/mindspore/offline_debug/convert_async.py +++ b/mindspore/offline_debug/convert_async.py @@ -12,92 +12,229 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Module to provide conversion capabalities from .timestamp async dump files to .npy.""" -import site +""" +Module to provide conversion capabalities from .timestamp async dump files to .npy. +It's an internal module for debugger backend but not exposed to users. +""" import os -DIR_PATH = "/usr/local/Ascend/toolkit/tools/operator_cmp/compare/" -if not os.path.exists(DIR_PATH): - raise ValueError("Directory " + DIR_PATH + " does not exist. Please install Ascend toolkit.") -site.addsitedir(DIR_PATH) -#pylint: disable=wrong-import-position -import argparse -import csv -from dump_data_parser import DumpDataParser -from shape_conversion import FormatConversionMain -import utils -#pylint: enable=wrong-import-position +import glob +import stat +import sys +from pathlib import Path +from importlib import import_module +from collections import namedtuple + +import numpy as np -def handle_multi_process(convert_obj, files): - """Convert async format files to npy in a multithreaded manner""" - #pylint: disable=W0212 - return_code = utils.VECTOR_COMPARISON_NONE_ERROR - convert_obj.progress = utils.Progress(len(files)) - multi_process_file_list = [] - big_file_list = [] - max_file_size = convert_obj._get_max_file_size() - for cur_file in files: - cur_path = cur_file - if os.path.isfile(cur_path): - if os.path.getsize(cur_path) > max_file_size: - big_file_list.append(cur_path) +class ConvertToolLoader: + """Module to load CANN conversion tool.""" + + def __init__(self): + self.utils = None + self.common = None + self.dump_data_parser = None + self.format_conversion = None + self.load_convert_tool() + + @staticmethod + def find_toolkit_path(): + """Find the path to Ascend toolkit.""" + ascend_install_path = "/usr/local/Ascend" + if not os.path.exists(ascend_install_path): + ascend_toolkit_path = os.getenv("ASCEND_TOOLKIT_PATH") + if not ascend_toolkit_path: + raise ValueError( + "Failed to get $ASCEND_TOOLKIT_PATH in environment. Please install run packages " + + "and set the environment variable correctly.") + ascend_install_path = ascend_toolkit_path + ascend_install_path = Path(ascend_install_path).resolve() + msaccucmp_file_list = list(ascend_install_path.rglob('msaccucmp.py*')) + if not msaccucmp_file_list: + raise ValueError("Failed to find msaccucmp.py or msaccucmp.pyc file under " + + ascend_install_path + ". Please install Ascend toolkit.") + return msaccucmp_file_list[0].parent + + def load_convert_tool(self): + """load CANN conversion tool from the toolkit path.""" + toolkit_path = self.find_toolkit_path() + if str(toolkit_path) not in sys.path: + sys.path.append(str(toolkit_path)) + try: + self.utils = import_module('utils') + self.common = import_module('common') + self.dump_data_parser = import_module( + 'dump_data_parser').DumpDataParser + self.format_conversion = import_module( + 'shape_conversion').FormatConversionMain + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Failed to load CANN conversion tools under " + toolkit_path + ". Please make sure Ascend " + + "toolkit has been installed properly.") + + +def parse_args(file_list, output_path): + """Helper function to parse the input argument for the conversion configuration.""" + args_dict = dict() + args_dict['dump_version'] = '2.0' + args_dict['format'] = 'NCHW' + args_dict['output_file_type'] = 'npy' + args_dict['dump_path'] = output_path + args_dict['output_path'] = output_path + args_dict['file_list'] = file_list + args_dict['input'] = None + args_dict['output'] = None + args_dict['shape'] = None + args_dict['custom_script_path'] = None + args_parser = namedtuple("args_parser", args_dict.keys()) + return args_parser(**args_dict) + + +class AsyncDumpConverter: + """Convert the target async dump data into npy files.""" + + def __init__(self, file_list, output_path): + # check input path + for file_item in file_list: + file_item = os.path.realpath(file_item) + output_path = os.path.realpath(output_path) + + self.convert_tool = ConvertToolLoader() + self.args = parse_args(file_list, output_path) + self.files_to_convert = self.args.file_list + self.output_path = self.args.output_path + self.failed_file_path = os.path.join( + self.output_path, 'convert_failed_file_list.txt') + self.clear_failed_list_file() + + def clear_failed_list_file(self): + """Remove existing failed txt file.""" + if self.failed_file_path and os.path.exists(self.failed_file_path): + os.remove(self.failed_file_path) + + def convert_files(self): + """Main entry of the converter to convert async dump files into npy format.""" + self.convert_tool.utils.print_info_log('Start to convert async dump files.') + ret_code = self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR + if self.args.format is not None: + convert = self.convert_tool.format_conversion(self.args) + else: + convert = self.convert_tool.dump_data_parser(self.args) + ret_code = self.handle_multi_process(convert, self.files_to_convert) + self._rename_generated_npy_files() + if ret_code != self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR: + if os.path.exists(self.failed_file_path): + self.convert_failed_tensors() + self.convert_tool.utils.print_info_log('Finish to convert async dump files.') + + def convert_failed_tensors(self): + """Convert the failed tensor recorded in the failed txt file.""" + self.convert_tool.utils.print_info_log( + 'Start to convert failed tensors recorded in ' + self.failed_file_path + '.') + with open(self.failed_file_path) as failed_lines: + for failed_line in failed_lines: + try: + failed_line_list = failed_line.rstrip().split(',') + self.convert_one_failed_tensor(failed_line_list) + except (ValueError, OSError, AttributeError, self.convert_tool.utils.CompareError) as err: + self.convert_tool.utils.print_error_log( + 'Failed to convert ' + failed_line + ' to Host format: ' + str(err)) + + def convert_one_failed_tensor(self, failed_tensor): + """Convert failed operator one by one.""" + if len(failed_tensor) <= 1: + raise ValueError( + "Invalid tensor info in convert_failed_file_list.txt") + file_path = failed_tensor[0] + type_index = failed_tensor[1:] + op_data = self.convert_tool.utils.parse_dump_file( + file_path, self.args.dump_version) + for type_index_item in type_index: + tensor_type, index = type_index_item.split(':') + index = int(index) + tensor = getattr(op_data, tensor_type)[index] + dump_data_array = self.convert_tool.utils.deserialize_dump_data_to_array(tensor) + array = dump_data_array.reshape(tensor.shape.dim) + self._save_tensor_to_npy_file( + file_path, tensor_type, index, tensor.format, array) + + def handle_multi_process(self, convert_obj, files): + """Convert async format files to npy in a multithreaded manner.""" + return_code = self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR + # try looking for function in compatibility with the toolkit package version. + if hasattr(convert_obj, 'multi_process'): + _ = setattr(convert_obj.multi_process, '_progress', self.convert_tool.utils.Progress(len(files))) + else: + _ = setattr(convert_obj, 'progress', self.convert_tool.utils.Progress(len(files))) + multi_process_file_list = [] + big_file_list = [] + max_file_size = 0 + if hasattr(convert_obj, 'multi_process'): + max_file_size = getattr(convert_obj.multi_process, 'get_max_file_size')() + else: + max_file_size = getattr(convert_obj, '_get_max_file_size')() + for cur_file in files: + cur_path = cur_file + if os.path.isfile(cur_path): + if os.path.getsize(cur_path) > max_file_size: + big_file_list.append(cur_path) + else: + multi_process_file_list.append(cur_path) + if multi_process_file_list: + ret_mp = self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR + if hasattr(convert_obj, 'multi_process'): + ret_mp = getattr(convert_obj.multi_process, '_do_multi_process')(multi_process_file_list) else: - multi_process_file_list.append(cur_path) + ret_mp = getattr(convert_obj, '_do_multi_process')(multi_process_file_list) + if ret_mp != self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR: + return_code = ret_mp + for big_file in big_file_list: + ret_bf = self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR + if hasattr(convert_obj, '_convert_format_for_one_file'): + ret_bf, _ = getattr(convert_obj, '_convert_format_for_one_file')(big_file) + else: + ret_bf, _ = getattr(convert_obj, 'convert_format_for_one_file')(big_file) + if hasattr(convert_obj, 'multi_process'): + getattr(convert_obj.multi_process, '_handle_result_callback')([ret_bf, big_file]) + else: + getattr(convert_obj, '_handle_result_callback')([ret_bf, big_file]) + if ret_bf != self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR: + return_code = ret_bf + if return_code != self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR: + if os.path.exists(self.failed_file_path): + self.convert_tool.utils.print_info_log( + 'The list of file that failed to convert has been written to "' + + self.failed_file_path + '".') + return return_code - if multi_process_file_list: - ret = convert_obj._do_multi_process(multi_process_file_list) - if ret != utils.VECTOR_COMPARISON_NONE_ERROR: - return_code = ret - for big_file in big_file_list: - ret, _ = convert_obj.convert_format_for_one_file(big_file) - convert_obj._handle_result_callback([ret, big_file]) - if ret != utils.VECTOR_COMPARISON_NONE_ERROR: - return_code = ret + def _save_tensor_to_npy_file(self, file_path, tensor_type, idx, tensor_format, dump_data_array): + """Save tensor file into npy format.""" + file_name = os.path.basename(file_path) + name_splits = file_name.split('.') + name_splits[1] = name_splits[1].split('_')[-1] + file_name_no_scope = '.'.join(name_splits) + out_file_name = "%s.%s.%d.%s.npy" % ( + file_name_no_scope, + tensor_type, + idx, + self.convert_tool.common.get_format_string(tensor_format) + ) + out_path = os.path.join(self.output_path, out_file_name) + np.save(out_path, dump_data_array) + os.chmod(out_path, stat.S_IRUSR) - if return_code != utils.VECTOR_COMPARISON_NONE_ERROR: - error_file_path = os.path.join( - convert_obj.output_path, utils.CONVERT_FAILED_FILE_LIST_NAME) - if os.path.exists(error_file_path): - utils.print_info_log( - 'The list of file that failed to convert has been written to "' + error_file_path + '".') - # pylint: enable=W0212 - return return_code - -if __name__ == "__main__": - convert_parser = argparse.ArgumentParser() - convert_parser.add_argument( - '-d', '--dump_file', dest='dump_path', default='', required=True) - convert_parser.add_argument( - '-l', '--file_list', nargs="*", dest='file_list', default='') - convert_parser.add_argument('-f', '--format', dest='format', default=None) - convert_parser.add_argument( - '-v', '--version', dest='dump_version', choices=[1, 2], type=int, default=2) - convert_parser.add_argument('-s', '--shape', dest='shape', default=None) - convert_parser.add_argument('-o', '--output_tensor', - dest='output', default=None) - convert_parser.add_argument('-i', '--input_tensor', dest='input', default=None) - convert_parser.add_argument( - '-c', '--custom_script_path', dest='custom_script_path', default=None) - convert_parser.add_argument('-out', '--output', dest='output_path', default='') - convert_parser.add_argument( - '-t', '--type', dest='output_file_type', choices=['npy', 'bin'], default='npy') - - args = convert_parser.parse_args() - dump_failed = os.path.abspath(args.dump_path) + "/convert_failed_file_list.txt" - if os.path.exists(dump_failed): - os.remove(dump_failed) - file_list = args.file_list - if args.format is not None: - convert = FormatConversionMain(args) - else: - convert = DumpDataParser(args) - if args.file_list == "": - file_list = os.listdir(args.dump_path) - handle_multi_process(convert, file_list) - if os.path.exists(dump_failed): - with open(dump_failed, newline='') as failed_ops: - file_reader = csv.reader(failed_ops, delimiter=',') - file_list = [os.path.abspath(row[0]) for row in file_reader] - args.format = None - convert = DumpDataParser(args) - handle_multi_process(convert, file_list) + def _rename_generated_npy_files(self): + """In order to follow dump naming convention, rename npy files generated by CANN conversion tool.""" + target_file_list = [] + for in_file in self.files_to_convert: + target_file_list.extend(glob.glob(in_file + "*.npy")) + for target_file in target_file_list: + old_filename = os.path.basename(target_file) + name_splits = old_filename.split('.') + name_splits[1] = name_splits[1].split('_')[-1] + name_splits[-2] = self.args.format + new_file_name = '.'.join(name_splits) + out_path = os.path.join(self.output_path, new_file_name) + os.rename(target_file, out_path) + os.chmod(out_path, stat.S_IRUSR) + self.convert_tool.utils.print_info_log("Rename file " + target_file + " to " + out_path) diff --git a/tests/st/dump/async_dump.json b/tests/st/dump/async_dump.json deleted file mode 100644 index e629ef2564a..00000000000 --- a/tests/st/dump/async_dump.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "common_dump_settings": { - "dump_mode": 0, - "path": "/test", - "net_name": "Net", - "iteration": "0", - "input_output": 2, - "kernels": ["Default/TensorAdd-op3"], - "support_device": [0,1,2,3,4,5,6,7], - "op_debug_mode": 0 - } -} \ No newline at end of file diff --git a/tests/st/dump/dump_test_utils.py b/tests/st/dump/dump_test_utils.py new file mode 100644 index 00000000000..a6e51c73ced --- /dev/null +++ b/tests/st/dump/dump_test_utils.py @@ -0,0 +1,90 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Utils for testing dump feature. +""" + +import json + +async_dump_dict = { + "common_dump_settings": { + "dump_mode": 0, + "path": "", + "net_name": "Net", + "iteration": "0", + "input_output": 2, + "kernels": ["Default/TensorAdd-op3"], + "support_device": [0, 1, 2, 3, 4, 5, 6, 7], + "op_debug_mode": 0 + } +} + +e2e_dump_dict = { + "common_dump_settings": { + "dump_mode": 0, + "path": "", + "net_name": "Net", + "iteration": "0", + "input_output": 0, + "kernels": ["Default/Conv-op12"], + "support_device": [0, 1, 2, 3, 4, 5, 6, 7], + "op_debug_mode": 0 + }, + "e2e_dump_settings": { + "enable": True, + "trans_flag": False + } +} + +async_dump_dict_2 = { + "common_dump_settings": { + "dump_mode": 0, + "path": "/tmp/async_dump/test_async_dump_net_multi_layer_mode1", + "net_name": "test", + "iteration": "0", + "input_output": 2, + "kernels": [ + "default/TensorAdd-op10", + "Gradients/Default/network-WithLossCell/_backbone-ReLUReduceMeanDenseRelu/dense-Dense/gradBiasAdd/"\ + "BiasAddGrad-op8", + "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/SoftmaxCrossEntropyWithLogits-op5", + "Default/optimizer-Momentum/tuple_getitem-op29", + "Default/optimizer-Momentum/ApplyMomentum-op12" + ], + "support_device": [0, 1, 2, 3, 4, 5, 6, 7], + "op_debug_mode": 0 + } +} + + +def generate_dump_json(dump_path, json_file_name, test_key): + """ + Util function to generate dump configuration json file. + """ + data = dict() + if test_key == "test_async_dump": + data = async_dump_dict + data["common_dump_settings"]["path"] = dump_path + elif test_key == "test_e2e_dump": + data = e2e_dump_dict + data["common_dump_settings"]["path"] = dump_path + elif test_key == "test_async_dump_net_multi_layer_mode1": + data = async_dump_dict_2 + data["common_dump_settings"]["path"] = dump_path + else: + raise ValueError( + "Failed to generate dump json file. The test name value " + test_key + " is invalid.") + with open(json_file_name, 'w') as f: + json.dump(data, f) diff --git a/tests/st/dump/e2e_dump.json b/tests/st/dump/e2e_dump.json deleted file mode 100644 index 73a0b6c96de..00000000000 --- a/tests/st/dump/e2e_dump.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "common_dump_settings": { - "dump_mode": 0, - "path": "/test", - "net_name": "Net", - "iteration": "0", - "input_output": 0, - "kernels": ["Default/Conv-op12"], - "support_device": [0,1,2,3,4,5,6,7], - "op_debug_mode": 0 - }, - "e2e_dump_settings": { - "enable": true, - "trans_flag": false - } -} \ No newline at end of file diff --git a/tests/st/dump/test_async_dump_net_multi_layer_mode1.json b/tests/st/dump/test_async_dump_net_multi_layer_mode1.json deleted file mode 100644 index 6ce51fc1a48..00000000000 --- a/tests/st/dump/test_async_dump_net_multi_layer_mode1.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "common_dump_settings":{ - "dump_mode": 0, - "path": "/tmp/async_dump/test_async_dump_net_multi_layer_mode1", - "net_name": "test", - "iteration": "0", - "input_output": 2, - "kernels": [ - "default/TensorAdd-op10", - "Gradients/Default/network-WithLossCell/_backbone-ReLUReduceMeanDenseRelu/dense-Dense/gradBiasAdd/BiasAddGrad-op8", - "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/SoftmaxCrossEntropyWithLogits-op5", - "Default/optimizer-Momentum/tuple_getitem-op29", - "Default/optimizer-Momentum/ApplyMomentum-op12" - ], - "support_device": [0,1,2,3,4,5,6,7], - "op_debug_mode": 0 - } -} \ No newline at end of file diff --git a/tests/st/dump/test_data_dump.py b/tests/st/dump/test_data_dump.py index 3ff753efdb6..f82a38288e7 100644 --- a/tests/st/dump/test_data_dump.py +++ b/tests/st/dump/test_data_dump.py @@ -13,13 +13,13 @@ # limitations under the License. # ============================================================================ import os -import json import sys import tempfile import time import shutil import glob - +from importlib import import_module +from pathlib import Path import numpy as np import pytest import mindspore.context as context @@ -32,6 +32,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits from mindspore.nn import Momentum from mindspore.nn import TrainOneStepCell from mindspore.nn import WithLossCell +from dump_test_utils import generate_dump_json class Net(nn.Cell): @@ -47,14 +48,6 @@ x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) y = np.array([[7, 8, 9], [10, 11, 12]]).astype(np.float32) -def change_current_dump_json(file_name, dump_path, dump_config_path): - with open(file_name, 'r+') as f: - data = json.load(f) - data["common_dump_settings"]["path"] = dump_path - with open(dump_config_path, 'w') as f: - json.dump(data, f) - - @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -65,7 +58,7 @@ def test_async_dump(): with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir: dump_path = os.path.join(tmp_dir, 'async_dump') dump_config_path = os.path.join(tmp_dir, 'async_dump.json') - change_current_dump_json('async_dump.json', dump_path, dump_config_path) + generate_dump_json(dump_path, dump_config_path, 'test_async_dump') os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0') if os.path.isdir(dump_path): @@ -83,7 +76,7 @@ def run_e2e_dump(): with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir: dump_path = os.path.join(tmp_dir, 'e2e_dump') dump_config_path = os.path.join(tmp_dir, 'e2e_dump.json') - change_current_dump_json('e2e_dump.json', dump_path, dump_config_path) + generate_dump_json(dump_path, dump_config_path, 'test_e2e_dump') os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0') if os.path.isdir(dump_path): @@ -178,69 +171,47 @@ class ReluReduceMeanDenseRelu(Cell): return x_ -def search_path(path, keyword): - content = os.listdir(path) - for each in content: - each_path = path + os.sep + each - if keyword in each: - return each_path - read_write = os.access(each_path, os.W_OK) and os.access(each_path, os.R_OK) - if not read_write: - continue - if os.path.isdir(each_path): - search_path(each_path, keyword) - return None - - @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard def test_async_dump_net_multi_layer_mode1(): context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - test_name = "test_async_dump_net_multi_layer_mode1" - json_file = os.path.join(os.getcwd(), "{}.json".format(test_name)) - rank_id = 0 - dump_full_path = os.path.join("/tmp/async_dump/", "{}_{}".format(test_name, rank_id)) - os.system("rm -rf {}/*".format(dump_full_path)) - os.environ["MINDSPORE_DUMP_CONFIG"] = json_file - weight = Tensor(np.ones((1000, 2048)).astype(np.float32)) - bias = Tensor(np.ones((1000,)).astype(np.float32)) - net = ReluReduceMeanDenseRelu(weight, bias, 2048, 1000) - criterion = SoftmaxCrossEntropyWithLogits(sparse=False) - optimizer = Momentum(learning_rate=0.1, momentum=0.1, - params=filter(lambda x: x.requires_grad, net.get_parameters())) - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) - train_network.set_train() - inputs = Tensor(np.random.randn(32, 2048, 7, 7).astype(np.float32)) - label = Tensor(np.zeros(shape=(32, 1000)).astype(np.float32)) - net_dict = train_network(inputs, label) - - dump_path = "/tmp/async_dump/{}/rank_{}/test/0/0/".format(test_name, rank_id) - dump_file = os.listdir(dump_path) - dump_file_name = "" - for file in dump_file: - if "SoftmaxCrossEntropyWithLogits" in file: - dump_file_name = file - dump_file_full_path = os.path.join(dump_path, dump_file_name) - npy_path = os.path.join(os.getcwd(), "./{}".format(test_name)) - if os.path.exists(npy_path): - shutil.rmtree(npy_path) - os.mkdir(npy_path) - tool_path = search_path('/usr/local/Ascend', 'msaccucmp.pyc') - if tool_path: - cmd = "python {0} convert -d {1} -out {2}".format(tool_path, dump_file_full_path, npy_path) - os.system(cmd) - npy_file_list = os.listdir(npy_path) - dump_result = {} - for file in npy_file_list: - if "output.0.npy" in file: - dump_result["output0"] = np.load(os.path.join(npy_path, file)) - for index, value in enumerate(net_dict): - assert value.asnumpy() == dump_result["output0"][index] - else: - print('not find convert tools msaccucmp.pyc') + pwd = os.getcwd() + with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir: + dump_path = os.path.join(tmp_dir, 'async_dump_net_multi_layer_mode1') + json_file_path = os.path.join(tmp_dir, "test_async_dump_net_multi_layer_mode1.json") + generate_dump_json(dump_path, json_file_path, 'test_async_dump_net_multi_layer_mode1') + os.environ['MINDSPORE_DUMP_CONFIG'] = json_file_path + weight = Tensor(np.ones((1000, 2048)).astype(np.float32)) + bias = Tensor(np.ones((1000,)).astype(np.float32)) + net = ReluReduceMeanDenseRelu(weight, bias, 2048, 1000) + criterion = SoftmaxCrossEntropyWithLogits(sparse=False) + optimizer = Momentum(learning_rate=0.1, momentum=0.1, + params=filter(lambda x: x.requires_grad, net.get_parameters())) + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) + train_network.set_train() + inputs = Tensor(np.random.randn(32, 2048, 7, 7).astype(np.float32)) + label = Tensor(np.zeros(shape=(32, 1000)).astype(np.float32)) + net_dict = train_network(inputs, label) + dump_file_path = os.path.join(dump_path, 'rank_0', 'test', '0', '0') + dump_file_name = list(Path(dump_file_path).rglob("*SoftmaxCrossEntropyWithLogits*"))[0] + dump_file_full_path = os.path.join(dump_file_path, dump_file_name) + npy_path = os.path.join(dump_path, "npy_files") + if os.path.exists(npy_path): + shutil.rmtree(npy_path) + os.mkdir(npy_path) + tool_path_search_list = list(Path('/usr/local/Ascend').rglob('msaccucmp.py*')) + if tool_path_search_list: + converter = import_module("mindspore.offline_debug.convert_async") + converter.AsyncDumpConverter([dump_file_full_path], npy_path).convert_files() + npy_result_file = list(Path(npy_path).rglob("*output.0.*.npy"))[0] + dump_result = np.load(os.path.join(npy_path, npy_result_file)) + for index, value in enumerate(net_dict): + assert value.asnumpy() == dump_result[index] + else: + print('Failed to find hisi convert tools: msaccucmp.py or msaccucmp.pyc.') @pytest.mark.level0 @@ -256,7 +227,7 @@ def test_dump_with_diagnostic_path(): pwd = os.getcwd() with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir: dump_config_path = os.path.join(tmp_dir, 'e2e_dump.json') - change_current_dump_json('e2e_dump.json', '', dump_config_path) + generate_dump_json('', dump_config_path, 'test_e2e_dump') os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path diagnose_path = os.path.join(tmp_dir, 'e2e_dump') os.environ['MS_DIAGNOSTIC_DATA_PATH'] = diagnose_path