Fix failure to convert async dump files and refactor convert_async.py

TinaMengtingZhang 2021-07-27 21:25:09 -04:00
parent 5a851daf2f
commit b17b2bc687
7 changed files with 360 additions and 213 deletions


@@ -26,6 +26,7 @@
 #include <unordered_set>
 #include <utility>
 #include "pybind11/embed.h"
+#include "pybind11/stl.h"
 #ifdef ONLINE_DBG_MODE
 #include "debug/common.h"
 #include "debug/debugger/debugger.h"
@@ -549,6 +550,7 @@ void DebugServices::ConvertToHostFormat(const std::map<std::string, std::vector<
   std::string file_format = "npy";
   for (auto const &d : dir_to_files_map) {
     std::vector<std::string> files_to_convert_in_dir;
+    std::vector<std::string> files_after_convert_in_dir;
     std::string dump_key = d.first;
     for (auto const &file_name : d.second) {
       bool already_converted = false;
@@ -567,26 +569,19 @@ void DebugServices::ConvertToHostFormat(const std::map<std::string, std::vector<
       }
       if (!already_converted) {
         files_to_convert_in_dir.push_back(dump_key + "/" + file_name);
+        files_after_convert_in_dir.push_back(dump_key + "/" + file_name_without_scope);
       }
     }
-    std::ostringstream input_file_o;
-    const char *const delim = " ";
-    std::copy(files_to_convert_in_dir.begin(), files_to_convert_in_dir.end(),
-              std::ostream_iterator<std::string>(input_file_o, delim));
-    std::string input_files = input_file_o.str();
-    MS_LOG(INFO) << "Ops to convert: " << input_files;
-    if (input_files != "") {
+    MS_LOG(INFO) << "Number of files to convert: " << files_to_convert_in_dir.size();
+    if (!files_to_convert_in_dir.empty()) {
      // Look for the installation path to the conver_async package. If not found, throw exception and terminate the
      // later task.
      try {
        auto pkg = pybind11::module::import("mindspore.offline_debug.convert_async");
-        std::string convert_pkg_path = pkg.attr("__file__").cast<std::string>();
-        MS_LOG(INFO) << "The file for converting async dump data is in " << convert_pkg_path;
-        std::string convert_command = "python " + convert_pkg_path + " -out " + dump_key + " -t " + file_format +
-                                      " -d " + dump_key + " -f NCHW -l " + input_files;
-        (void)(system(convert_command.c_str()) + 1);
+        auto convert_obj = pkg.attr("AsyncDumpConverter")(pybind11::cast(files_to_convert_in_dir), dump_key);
+        (void)convert_obj.attr("convert_files")();
      } catch (pybind11::error_already_set &e) {
-        MS_LOG(EXCEPTION) << "Can't find package mindspore.offline_debug.convert_async";
+        MS_LOG(EXCEPTION) << "Failed to convert async dump data: " << e.what();
      }
      std::string abspath = RealPath(dump_key);
@@ -599,7 +594,7 @@ void DebugServices::ConvertToHostFormat(const std::map<std::string, std::vector<
     while ((dir = readdir(d_handle)) != nullptr) {
       if (dir->d_type == DT_REG) {
         std::string candidate = dir->d_name;
-        for (const std::string &file_to_find : files_to_convert_in_dir) {
+        for (const std::string &file_to_find : files_after_convert_in_dir) {
           std::string file_n = file_to_find.substr(file_to_find.find_last_of("\\/") + 1);
           if (candidate.find(file_n) != std::string::npos && candidate.rfind(file_format) != std::string::npos) {
             // we found a converted file for this op

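Note: the embedded pybind11 call above is roughly equivalent to driving the new converter directly from Python. A minimal sketch follows; the dump directory and file name are hypothetical placeholders for dump_key and files_to_convert_in_dir.

    from mindspore.offline_debug.convert_async import AsyncDumpConverter

    dump_dir = "/path/to/rank_0/Net/0/0"  # hypothetical stand-in for dump_key
    files = [dump_dir + "/Add.Default_Add-op1.1.2.1627436759087051"]  # hypothetical async dump file
    AsyncDumpConverter(files, dump_dir).convert_files()  # same call the C++ code now makes via pybind11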

@@ -12,92 +12,229 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Module to provide conversion capabalities from .timestamp async dump files to .npy."""
-import site
-import os
-DIR_PATH = "/usr/local/Ascend/toolkit/tools/operator_cmp/compare/"
-if not os.path.exists(DIR_PATH):
-    raise ValueError("Directory " + DIR_PATH + " does not exist. Please install Ascend toolkit.")
-site.addsitedir(DIR_PATH)
-#pylint: disable=wrong-import-position
-import argparse
-import csv
-from dump_data_parser import DumpDataParser
-from shape_conversion import FormatConversionMain
-import utils
-#pylint: enable=wrong-import-position
-
-
-def handle_multi_process(convert_obj, files):
-    """Convert async format files to npy in a multithreaded manner"""
-    #pylint: disable=W0212
-    return_code = utils.VECTOR_COMPARISON_NONE_ERROR
-    convert_obj.progress = utils.Progress(len(files))
-    multi_process_file_list = []
-    big_file_list = []
-    max_file_size = convert_obj._get_max_file_size()
-    for cur_file in files:
-        cur_path = cur_file
-        if os.path.isfile(cur_path):
-            if os.path.getsize(cur_path) > max_file_size:
-                big_file_list.append(cur_path)
-            else:
-                multi_process_file_list.append(cur_path)
-    if multi_process_file_list:
-        ret = convert_obj._do_multi_process(multi_process_file_list)
-        if ret != utils.VECTOR_COMPARISON_NONE_ERROR:
-            return_code = ret
-    for big_file in big_file_list:
-        ret, _ = convert_obj.convert_format_for_one_file(big_file)
-        convert_obj._handle_result_callback([ret, big_file])
-        if ret != utils.VECTOR_COMPARISON_NONE_ERROR:
-            return_code = ret
-    if return_code != utils.VECTOR_COMPARISON_NONE_ERROR:
-        error_file_path = os.path.join(
-            convert_obj.output_path, utils.CONVERT_FAILED_FILE_LIST_NAME)
-        if os.path.exists(error_file_path):
-            utils.print_info_log(
-                'The list of file that failed to convert has been written to "' + error_file_path + '".')
-    # pylint: enable=W0212
-    return return_code
-
-
-if __name__ == "__main__":
-    convert_parser = argparse.ArgumentParser()
-    convert_parser.add_argument(
-        '-d', '--dump_file', dest='dump_path', default='', required=True)
-    convert_parser.add_argument(
-        '-l', '--file_list', nargs="*", dest='file_list', default='')
-    convert_parser.add_argument('-f', '--format', dest='format', default=None)
-    convert_parser.add_argument(
-        '-v', '--version', dest='dump_version', choices=[1, 2], type=int, default=2)
-    convert_parser.add_argument('-s', '--shape', dest='shape', default=None)
-    convert_parser.add_argument('-o', '--output_tensor',
-                                dest='output', default=None)
-    convert_parser.add_argument('-i', '--input_tensor', dest='input', default=None)
-    convert_parser.add_argument(
-        '-c', '--custom_script_path', dest='custom_script_path', default=None)
-    convert_parser.add_argument('-out', '--output', dest='output_path', default='')
-    convert_parser.add_argument(
-        '-t', '--type', dest='output_file_type', choices=['npy', 'bin'], default='npy')
-    args = convert_parser.parse_args()
-
-    dump_failed = os.path.abspath(args.dump_path) + "/convert_failed_file_list.txt"
-    if os.path.exists(dump_failed):
-        os.remove(dump_failed)
-    file_list = args.file_list
-    if args.format is not None:
-        convert = FormatConversionMain(args)
-    else:
-        convert = DumpDataParser(args)
-    if args.file_list == "":
-        file_list = os.listdir(args.dump_path)
-    handle_multi_process(convert, file_list)
-    if os.path.exists(dump_failed):
-        with open(dump_failed, newline='') as failed_ops:
-            file_reader = csv.reader(failed_ops, delimiter=',')
-            file_list = [os.path.abspath(row[0]) for row in file_reader]
-        args.format = None
-        convert = DumpDataParser(args)
-        handle_multi_process(convert, file_list)
+"""
+Module to provide conversion capabalities from .timestamp async dump files to .npy.
+It's an internal module for debugger backend but not exposed to users.
+"""
+import os
+import glob
+import stat
+import sys
+from pathlib import Path
+from importlib import import_module
+from collections import namedtuple
+
+import numpy as np
+
+
+class ConvertToolLoader:
+    """Module to load CANN conversion tool."""
+
+    def __init__(self):
+        self.utils = None
+        self.common = None
+        self.dump_data_parser = None
+        self.format_conversion = None
+        self.load_convert_tool()
+
+    @staticmethod
+    def find_toolkit_path():
+        """Find the path to Ascend toolkit."""
+        ascend_install_path = "/usr/local/Ascend"
+        if not os.path.exists(ascend_install_path):
+            ascend_toolkit_path = os.getenv("ASCEND_TOOLKIT_PATH")
+            if not ascend_toolkit_path:
+                raise ValueError(
+                    "Failed to get $ASCEND_TOOLKIT_PATH in environment. Please install run packages " +
+                    "and set the environment variable correctly.")
+            ascend_install_path = ascend_toolkit_path
+        ascend_install_path = Path(ascend_install_path).resolve()
+        msaccucmp_file_list = list(ascend_install_path.rglob('msaccucmp.py*'))
+        if not msaccucmp_file_list:
+            raise ValueError("Failed to find msaccucmp.py or msaccucmp.pyc file under " +
+                             ascend_install_path + ". Please install Ascend toolkit.")
+        return msaccucmp_file_list[0].parent
+
+    def load_convert_tool(self):
+        """load CANN conversion tool from the toolkit path."""
+        toolkit_path = self.find_toolkit_path()
+        if str(toolkit_path) not in sys.path:
+            sys.path.append(str(toolkit_path))
+        try:
+            self.utils = import_module('utils')
+            self.common = import_module('common')
+            self.dump_data_parser = import_module(
+                'dump_data_parser').DumpDataParser
+            self.format_conversion = import_module(
+                'shape_conversion').FormatConversionMain
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Failed to load CANN conversion tools under " + toolkit_path + ". Please make sure Ascend " +
+                "toolkit has been installed properly.")
+
+
+def parse_args(file_list, output_path):
+    """Helper function to parse the input argument for the conversion configuration."""
+    args_dict = dict()
+    args_dict['dump_version'] = '2.0'
+    args_dict['format'] = 'NCHW'
+    args_dict['output_file_type'] = 'npy'
+    args_dict['dump_path'] = output_path
+    args_dict['output_path'] = output_path
+    args_dict['file_list'] = file_list
+    args_dict['input'] = None
+    args_dict['output'] = None
+    args_dict['shape'] = None
+    args_dict['custom_script_path'] = None
+    args_parser = namedtuple("args_parser", args_dict.keys())
+    return args_parser(**args_dict)
+
+
+class AsyncDumpConverter:
+    """Convert the target async dump data into npy files."""
+
+    def __init__(self, file_list, output_path):
+        # check input path
+        for file_item in file_list:
+            file_item = os.path.realpath(file_item)
+        output_path = os.path.realpath(output_path)
+
+        self.convert_tool = ConvertToolLoader()
+        self.args = parse_args(file_list, output_path)
+        self.files_to_convert = self.args.file_list
+        self.output_path = self.args.output_path
+        self.failed_file_path = os.path.join(
+            self.output_path, 'convert_failed_file_list.txt')
+        self.clear_failed_list_file()
+
+    def clear_failed_list_file(self):
+        """Remove existing failed txt file."""
+        if self.failed_file_path and os.path.exists(self.failed_file_path):
+            os.remove(self.failed_file_path)
+
+    def convert_files(self):
+        """Main entry of the converter to convert async dump files into npy format."""
+        self.convert_tool.utils.print_info_log('Start to convert async dump files.')
+        ret_code = self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR
+        if self.args.format is not None:
+            convert = self.convert_tool.format_conversion(self.args)
+        else:
+            convert = self.convert_tool.dump_data_parser(self.args)
+        ret_code = self.handle_multi_process(convert, self.files_to_convert)
+        self._rename_generated_npy_files()
+        if ret_code != self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR:
+            if os.path.exists(self.failed_file_path):
+                self.convert_failed_tensors()
+        self.convert_tool.utils.print_info_log('Finish to convert async dump files.')
+
+    def convert_failed_tensors(self):
+        """Convert the failed tensor recorded in the failed txt file."""
+        self.convert_tool.utils.print_info_log(
+            'Start to convert failed tensors recorded in ' + self.failed_file_path + '.')
+        with open(self.failed_file_path) as failed_lines:
+            for failed_line in failed_lines:
+                try:
+                    failed_line_list = failed_line.rstrip().split(',')
+                    self.convert_one_failed_tensor(failed_line_list)
+                except (ValueError, OSError, AttributeError, self.convert_tool.utils.CompareError) as err:
+                    self.convert_tool.utils.print_error_log(
+                        'Failed to convert ' + failed_line + ' to Host format: ' + str(err))
+
+    def convert_one_failed_tensor(self, failed_tensor):
+        """Convert failed operator one by one."""
+        if len(failed_tensor) <= 1:
+            raise ValueError(
+                "Invalid tensor info in convert_failed_file_list.txt")
+        file_path = failed_tensor[0]
+        type_index = failed_tensor[1:]
+        op_data = self.convert_tool.utils.parse_dump_file(
+            file_path, self.args.dump_version)
+        for type_index_item in type_index:
+            tensor_type, index = type_index_item.split(':')
+            index = int(index)
+            tensor = getattr(op_data, tensor_type)[index]
+            dump_data_array = self.convert_tool.utils.deserialize_dump_data_to_array(tensor)
+            array = dump_data_array.reshape(tensor.shape.dim)
+            self._save_tensor_to_npy_file(
+                file_path, tensor_type, index, tensor.format, array)
+
+    def handle_multi_process(self, convert_obj, files):
+        """Convert async format files to npy in a multithreaded manner."""
+        return_code = self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR
+        # try looking for function in compatibility with the toolkit package version.
+        if hasattr(convert_obj, 'multi_process'):
+            _ = setattr(convert_obj.multi_process, '_progress', self.convert_tool.utils.Progress(len(files)))
+        else:
+            _ = setattr(convert_obj, 'progress', self.convert_tool.utils.Progress(len(files)))
+        multi_process_file_list = []
+        big_file_list = []
+        max_file_size = 0
+        if hasattr(convert_obj, 'multi_process'):
+            max_file_size = getattr(convert_obj.multi_process, 'get_max_file_size')()
+        else:
+            max_file_size = getattr(convert_obj, '_get_max_file_size')()
+        for cur_file in files:
+            cur_path = cur_file
+            if os.path.isfile(cur_path):
+                if os.path.getsize(cur_path) > max_file_size:
+                    big_file_list.append(cur_path)
+                else:
+                    multi_process_file_list.append(cur_path)
+        if multi_process_file_list:
+            ret_mp = self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR
+            if hasattr(convert_obj, 'multi_process'):
+                ret_mp = getattr(convert_obj.multi_process, '_do_multi_process')(multi_process_file_list)
+            else:
+                ret_mp = getattr(convert_obj, '_do_multi_process')(multi_process_file_list)
+            if ret_mp != self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR:
+                return_code = ret_mp
+        for big_file in big_file_list:
+            ret_bf = self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR
+            if hasattr(convert_obj, '_convert_format_for_one_file'):
+                ret_bf, _ = getattr(convert_obj, '_convert_format_for_one_file')(big_file)
+            else:
+                ret_bf, _ = getattr(convert_obj, 'convert_format_for_one_file')(big_file)
+            if hasattr(convert_obj, 'multi_process'):
+                getattr(convert_obj.multi_process, '_handle_result_callback')([ret_bf, big_file])
+            else:
+                getattr(convert_obj, '_handle_result_callback')([ret_bf, big_file])
+            if ret_bf != self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR:
+                return_code = ret_bf
+        if return_code != self.convert_tool.utils.VECTOR_COMPARISON_NONE_ERROR:
+            if os.path.exists(self.failed_file_path):
+                self.convert_tool.utils.print_info_log(
+                    'The list of file that failed to convert has been written to "'
+                    + self.failed_file_path + '".')
+        return return_code
+
+    def _save_tensor_to_npy_file(self, file_path, tensor_type, idx, tensor_format, dump_data_array):
+        """Save tensor file into npy format."""
+        file_name = os.path.basename(file_path)
+        name_splits = file_name.split('.')
+        name_splits[1] = name_splits[1].split('_')[-1]
+        file_name_no_scope = '.'.join(name_splits)
+        out_file_name = "%s.%s.%d.%s.npy" % (
+            file_name_no_scope,
+            tensor_type,
+            idx,
+            self.convert_tool.common.get_format_string(tensor_format)
+        )
+        out_path = os.path.join(self.output_path, out_file_name)
+        np.save(out_path, dump_data_array)
+        os.chmod(out_path, stat.S_IRUSR)
+
+    def _rename_generated_npy_files(self):
+        """In order to follow dump naming convention, rename npy files generated by CANN conversion tool."""
+        target_file_list = []
+        for in_file in self.files_to_convert:
+            target_file_list.extend(glob.glob(in_file + "*.npy"))
+        for target_file in target_file_list:
+            old_filename = os.path.basename(target_file)
+            name_splits = old_filename.split('.')
+            name_splits[1] = name_splits[1].split('_')[-1]
+            name_splits[-2] = self.args.format
+            new_file_name = '.'.join(name_splits)
+            out_path = os.path.join(self.output_path, new_file_name)
+            os.rename(target_file, out_path)
+            os.chmod(out_path, stat.S_IRUSR)
+            self.convert_tool.utils.print_info_log("Rename file " + target_file + " to " + out_path)

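Note: the retry path above assumes that each record in convert_failed_file_list.txt is a comma-separated line whose first field is the dump file path and whose remaining fields are tensor_type:index pairs, which is how convert_one_failed_tensor() parses it. A minimal sketch of that parsing (the record below is a hypothetical example):

    record = "/path/to/Conv2D.Default_Conv2D-op12.1.2.1627436759087051,input:0,output:0"  # hypothetical record
    fields = record.rstrip().split(',')
    file_path, type_index = fields[0], fields[1:]
    for item in type_index:
        tensor_type, index = item.split(':')  # e.g. ('output', '0'); that tensor is then re-saved as .npy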

@@ -1,12 +0,0 @@
-{
-    "common_dump_settings": {
-        "dump_mode": 0,
-        "path": "/test",
-        "net_name": "Net",
-        "iteration": "0",
-        "input_output": 2,
-        "kernels": ["Default/TensorAdd-op3"],
-        "support_device": [0,1,2,3,4,5,6,7],
-        "op_debug_mode": 0
-    }
-}


@@ -0,0 +1,90 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Utils for testing dump feature.
+"""
+import json
+
+async_dump_dict = {
+    "common_dump_settings": {
+        "dump_mode": 0,
+        "path": "",
+        "net_name": "Net",
+        "iteration": "0",
+        "input_output": 2,
+        "kernels": ["Default/TensorAdd-op3"],
+        "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
+        "op_debug_mode": 0
+    }
+}
+
+e2e_dump_dict = {
+    "common_dump_settings": {
+        "dump_mode": 0,
+        "path": "",
+        "net_name": "Net",
+        "iteration": "0",
+        "input_output": 0,
+        "kernels": ["Default/Conv-op12"],
+        "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
+        "op_debug_mode": 0
+    },
+    "e2e_dump_settings": {
+        "enable": True,
+        "trans_flag": False
+    }
+}
+
+async_dump_dict_2 = {
+    "common_dump_settings": {
+        "dump_mode": 0,
+        "path": "/tmp/async_dump/test_async_dump_net_multi_layer_mode1",
+        "net_name": "test",
+        "iteration": "0",
+        "input_output": 2,
+        "kernels": [
+            "default/TensorAdd-op10",
+            "Gradients/Default/network-WithLossCell/_backbone-ReLUReduceMeanDenseRelu/dense-Dense/gradBiasAdd/"\
+            "BiasAddGrad-op8",
+            "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/SoftmaxCrossEntropyWithLogits-op5",
+            "Default/optimizer-Momentum/tuple_getitem-op29",
+            "Default/optimizer-Momentum/ApplyMomentum-op12"
+        ],
+        "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
+        "op_debug_mode": 0
+    }
+}
+
+
+def generate_dump_json(dump_path, json_file_name, test_key):
+    """
+    Util function to generate dump configuration json file.
+    """
+    data = dict()
+    if test_key == "test_async_dump":
+        data = async_dump_dict
+        data["common_dump_settings"]["path"] = dump_path
+    elif test_key == "test_e2e_dump":
+        data = e2e_dump_dict
+        data["common_dump_settings"]["path"] = dump_path
+    elif test_key == "test_async_dump_net_multi_layer_mode1":
+        data = async_dump_dict_2
+        data["common_dump_settings"]["path"] = dump_path
+    else:
+        raise ValueError(
+            "Failed to generate dump json file. The test name value " + test_key + " is invalid.")
+    with open(json_file_name, 'w') as f:
+        json.dump(data, f)

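Note: a usage sketch of the helper above, matching how the tests below call it (the paths here are hypothetical):

    import os
    from dump_test_utils import generate_dump_json

    generate_dump_json('/tmp/my_async_dump', '/tmp/async_dump.json', 'test_async_dump')
    os.environ['MINDSPORE_DUMP_CONFIG'] = '/tmp/async_dump.json'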

@@ -1,16 +0,0 @@
-{
-    "common_dump_settings": {
-        "dump_mode": 0,
-        "path": "/test",
-        "net_name": "Net",
-        "iteration": "0",
-        "input_output": 0,
-        "kernels": ["Default/Conv-op12"],
-        "support_device": [0,1,2,3,4,5,6,7],
-        "op_debug_mode": 0
-    },
-    "e2e_dump_settings": {
-        "enable": true,
-        "trans_flag": false
-    }
-}


@@ -1,18 +0,0 @@
-{
-    "common_dump_settings":{
-        "dump_mode": 0,
-        "path": "/tmp/async_dump/test_async_dump_net_multi_layer_mode1",
-        "net_name": "test",
-        "iteration": "0",
-        "input_output": 2,
-        "kernels": [
-            "default/TensorAdd-op10",
-            "Gradients/Default/network-WithLossCell/_backbone-ReLUReduceMeanDenseRelu/dense-Dense/gradBiasAdd/BiasAddGrad-op8",
-            "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/SoftmaxCrossEntropyWithLogits-op5",
-            "Default/optimizer-Momentum/tuple_getitem-op29",
-            "Default/optimizer-Momentum/ApplyMomentum-op12"
-        ],
-        "support_device": [0,1,2,3,4,5,6,7],
-        "op_debug_mode": 0
-    }
-}


@@ -13,13 +13,13 @@
 # limitations under the License.
 # ============================================================================
 import os
-import json
 import sys
 import tempfile
 import time
 import shutil
 import glob
+from importlib import import_module
+from pathlib import Path
 import numpy as np
 import pytest
 import mindspore.context as context
@@ -32,6 +32,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
 from mindspore.nn import Momentum
 from mindspore.nn import TrainOneStepCell
 from mindspore.nn import WithLossCell
+from dump_test_utils import generate_dump_json


 class Net(nn.Cell):
@@ -47,14 +48,6 @@ x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
 y = np.array([[7, 8, 9], [10, 11, 12]]).astype(np.float32)


-def change_current_dump_json(file_name, dump_path, dump_config_path):
-    with open(file_name, 'r+') as f:
-        data = json.load(f)
-    data["common_dump_settings"]["path"] = dump_path
-    with open(dump_config_path, 'w') as f:
-        json.dump(data, f)
-
-
 @pytest.mark.level1
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -65,7 +58,7 @@ def test_async_dump():
     with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir:
         dump_path = os.path.join(tmp_dir, 'async_dump')
         dump_config_path = os.path.join(tmp_dir, 'async_dump.json')
-        change_current_dump_json('async_dump.json', dump_path, dump_config_path)
+        generate_dump_json(dump_path, dump_config_path, 'test_async_dump')
         os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
         dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
         if os.path.isdir(dump_path):
@@ -83,7 +76,7 @@ def run_e2e_dump():
     with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir:
         dump_path = os.path.join(tmp_dir, 'e2e_dump')
         dump_config_path = os.path.join(tmp_dir, 'e2e_dump.json')
-        change_current_dump_json('e2e_dump.json', dump_path, dump_config_path)
+        generate_dump_json(dump_path, dump_config_path, 'test_e2e_dump')
        os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
         dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
         if os.path.isdir(dump_path):
@@ -178,69 +171,47 @@ class ReluReduceMeanDenseRelu(Cell):
         return x_


-def search_path(path, keyword):
-    content = os.listdir(path)
-    for each in content:
-        each_path = path + os.sep + each
-        if keyword in each:
-            return each_path
-        read_write = os.access(each_path, os.W_OK) and os.access(each_path, os.R_OK)
-        if not read_write:
-            continue
-        if os.path.isdir(each_path):
-            search_path(each_path, keyword)
-    return None
-
-
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
 def test_async_dump_net_multi_layer_mode1():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    test_name = "test_async_dump_net_multi_layer_mode1"
-    json_file = os.path.join(os.getcwd(), "{}.json".format(test_name))
-    rank_id = 0
-    dump_full_path = os.path.join("/tmp/async_dump/", "{}_{}".format(test_name, rank_id))
-    os.system("rm -rf {}/*".format(dump_full_path))
-    os.environ["MINDSPORE_DUMP_CONFIG"] = json_file
-    weight = Tensor(np.ones((1000, 2048)).astype(np.float32))
-    bias = Tensor(np.ones((1000,)).astype(np.float32))
-    net = ReluReduceMeanDenseRelu(weight, bias, 2048, 1000)
-    criterion = SoftmaxCrossEntropyWithLogits(sparse=False)
-    optimizer = Momentum(learning_rate=0.1, momentum=0.1,
-                         params=filter(lambda x: x.requires_grad, net.get_parameters()))
-    net_with_criterion = WithLossCell(net, criterion)
-    train_network = TrainOneStepCell(net_with_criterion, optimizer)
-    train_network.set_train()
-    inputs = Tensor(np.random.randn(32, 2048, 7, 7).astype(np.float32))
-    label = Tensor(np.zeros(shape=(32, 1000)).astype(np.float32))
-    net_dict = train_network(inputs, label)
-
-    dump_path = "/tmp/async_dump/{}/rank_{}/test/0/0/".format(test_name, rank_id)
-    dump_file = os.listdir(dump_path)
-    dump_file_name = ""
-    for file in dump_file:
-        if "SoftmaxCrossEntropyWithLogits" in file:
-            dump_file_name = file
-    dump_file_full_path = os.path.join(dump_path, dump_file_name)
-    npy_path = os.path.join(os.getcwd(), "./{}".format(test_name))
-    if os.path.exists(npy_path):
-        shutil.rmtree(npy_path)
-    os.mkdir(npy_path)
-    tool_path = search_path('/usr/local/Ascend', 'msaccucmp.pyc')
-    if tool_path:
-        cmd = "python {0} convert -d {1} -out {2}".format(tool_path, dump_file_full_path, npy_path)
-        os.system(cmd)
-        npy_file_list = os.listdir(npy_path)
-        dump_result = {}
-        for file in npy_file_list:
-            if "output.0.npy" in file:
-                dump_result["output0"] = np.load(os.path.join(npy_path, file))
-        for index, value in enumerate(net_dict):
-            assert value.asnumpy() == dump_result["output0"][index]
-    else:
-        print('not find convert tools msaccucmp.pyc')
+    pwd = os.getcwd()
+    with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir:
+        dump_path = os.path.join(tmp_dir, 'async_dump_net_multi_layer_mode1')
+        json_file_path = os.path.join(tmp_dir, "test_async_dump_net_multi_layer_mode1.json")
+        generate_dump_json(dump_path, json_file_path, 'test_async_dump_net_multi_layer_mode1')
+        os.environ['MINDSPORE_DUMP_CONFIG'] = json_file_path
+        weight = Tensor(np.ones((1000, 2048)).astype(np.float32))
+        bias = Tensor(np.ones((1000,)).astype(np.float32))
+        net = ReluReduceMeanDenseRelu(weight, bias, 2048, 1000)
+        criterion = SoftmaxCrossEntropyWithLogits(sparse=False)
+        optimizer = Momentum(learning_rate=0.1, momentum=0.1,
+                             params=filter(lambda x: x.requires_grad, net.get_parameters()))
+        net_with_criterion = WithLossCell(net, criterion)
+        train_network = TrainOneStepCell(net_with_criterion, optimizer)
+        train_network.set_train()
+        inputs = Tensor(np.random.randn(32, 2048, 7, 7).astype(np.float32))
+        label = Tensor(np.zeros(shape=(32, 1000)).astype(np.float32))
+        net_dict = train_network(inputs, label)
+        dump_file_path = os.path.join(dump_path, 'rank_0', 'test', '0', '0')
+        dump_file_name = list(Path(dump_file_path).rglob("*SoftmaxCrossEntropyWithLogits*"))[0]
+        dump_file_full_path = os.path.join(dump_file_path, dump_file_name)
+        npy_path = os.path.join(dump_path, "npy_files")
+        if os.path.exists(npy_path):
+            shutil.rmtree(npy_path)
+        os.mkdir(npy_path)
+        tool_path_search_list = list(Path('/usr/local/Ascend').rglob('msaccucmp.py*'))
+        if tool_path_search_list:
+            converter = import_module("mindspore.offline_debug.convert_async")
+            converter.AsyncDumpConverter([dump_file_full_path], npy_path).convert_files()
+            npy_result_file = list(Path(npy_path).rglob("*output.0.*.npy"))[0]
+            dump_result = np.load(os.path.join(npy_path, npy_result_file))
+            for index, value in enumerate(net_dict):
+                assert value.asnumpy() == dump_result[index]
+        else:
+            print('Failed to find hisi convert tools: msaccucmp.py or msaccucmp.pyc.')


 @pytest.mark.level0
@@ -256,7 +227,7 @@ def test_dump_with_diagnostic_path():
     pwd = os.getcwd()
     with tempfile.TemporaryDirectory(dir=pwd) as tmp_dir:
         dump_config_path = os.path.join(tmp_dir, 'e2e_dump.json')
-        change_current_dump_json('e2e_dump.json', '', dump_config_path)
+        generate_dump_json('', dump_config_path, 'test_e2e_dump')
         os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
         diagnose_path = os.path.join(tmp_dir, 'e2e_dump')
         os.environ['MS_DIAGNOSTIC_DATA_PATH'] = diagnose_path