change validator to accept -1 as valid input and fix scope filename issue in async

This commit is contained in:
TinaMengtingZhang 2021-06-17 14:09:28 -04:00
parent 1c991331b9
commit dc6a661a2e
4 changed files with 25 additions and 17 deletions

View File

@ -113,7 +113,7 @@ void int_to_byte(size_t number, char *byte, size_t length) {
std::string GenerateNpyHeader(const ShapeVector &shape, TypeId type_id, bool fortran_order) {
auto type_desc = type_desc_map.find(type_id);
if (type_desc == type_desc_map.end()) {
MS_LOG(WARNING) << "Not support dump the " << TypeIdToType(type_id)->ToString() << " data to npy file.";
MS_LOG(INFO) << "Not support dump the " << TypeIdToType(type_id)->ToString() << " data to npy file.";
return std::string();
}

View File

@ -636,16 +636,22 @@ void DebugServices::AddToTensorData(const std::string &backend_name, const std::
void DebugServices::SetPrefixToCheck(std::string *prefix_dump_file_name, std::string *dump_style_kernel_name,
size_t slot) {
std::string dump_style_name_part = *dump_style_kernel_name;
std::size_t last_scope_marker;
std::string delim;
if (is_sync_mode) {
std::string dump_style_name_part = *dump_style_kernel_name;
std::size_t last_scope_marker = dump_style_kernel_name->rfind("--");
if (last_scope_marker != std::string::npos) {
dump_style_name_part = dump_style_kernel_name->substr(last_scope_marker + 2);
}
*prefix_dump_file_name = dump_style_name_part + ".output." + std::to_string(slot);
delim = "--";
} else {
*prefix_dump_file_name = *dump_style_kernel_name;
delim = "_";
}
last_scope_marker = dump_style_kernel_name->rfind(delim);
if (last_scope_marker != std::string::npos) {
dump_style_name_part = dump_style_kernel_name->substr(last_scope_marker + delim.size());
}
if (is_sync_mode) {
dump_style_name_part += ".output." + std::to_string(slot);
}
*prefix_dump_file_name = dump_style_name_part;
}
void DebugServices::ReadDumpedTensor(std::vector<std::string> backend_name, std::vector<size_t> slot,
@ -728,10 +734,7 @@ void DebugServices::ReadDumpedTensor(std::vector<std::string> backend_name, std:
bool found = false;
// if async mode
for (const std::string &file_path : async_file_pool) {
size_t first_dot_pos = file_path.find('.');
size_t second_dot_pos = file_path.find('.', first_dot_pos + 1);
std::string file_path_op_name = file_path.substr(first_dot_pos + 1, second_dot_pos - first_dot_pos - 1);
if (file_path_op_name == prefix_dump_file_name &&
if (file_path.find(prefix_dump_file_name) != std::string::npos &&
file_path.find(".output." + std::to_string(slot[i])) != std::string::npos) {
found = true;
shape.clear();
@ -761,8 +764,8 @@ std::string DebugServices::GetStrippedFilename(const std::string &file_name) {
// Look for the second dot's position from the back to avoid issue due to dots in the node name.
size_t second_dot = fifth_dot;
const int8_t kSeconeDotPosition = 2;
for (int8_t pos = 5; pos > kSeconeDotPosition; pos--) {
const int8_t kSecondDotPosition = 2;
for (int8_t pos = 5; pos > kSecondDotPosition; pos--) {
second_dot = file_name.rfind(".", second_dot - 1);
}

View File

@ -48,6 +48,11 @@ def check_uint64(arg, arg_name=""):
check_value(arg, [UINT64_MIN, UINT64_MAX])
def check_iteration(arg, arg_name=""):
type_check(arg, (int,), arg_name)
check_value(arg, [-1, UINT64_MAX])
def check_dir(dataset_dir):
if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))

View File

@ -18,7 +18,7 @@ Validator Functions for Offline Debugger APIs.
from functools import wraps
import mindspore.offline_debug.dbg_services as cds
from mindspore.offline_debug.mi_validator_helpers import parse_user_args, type_check, type_check_list, check_dir, check_uint32, check_uint64
from mindspore.offline_debug.mi_validator_helpers import parse_user_args, type_check, type_check_list, check_dir, check_uint32, check_uint64, check_iteration
def check_init(method):
@ -114,7 +114,7 @@ def check_check_watchpoints(method):
def new_method(self, *args, **kwargs):
[iteration], _ = parse_user_args(method, *args, **kwargs)
check_uint32(iteration, "iteration")
check_iteration(iteration, "iteration")
return method(self, *args, **kwargs)
@ -159,7 +159,7 @@ def check_tensor_info_init(method):
type_check(node_name, (str,), "node_name")
check_uint32(slot, "slot")
check_uint32(iteration, "iteration")
check_iteration(iteration, "iteration")
check_uint32(device_id, "device_id")
check_uint32(root_graph_id, "root_graph_id")
type_check(is_parameter, (bool,), "is_parameter")