From 9145dd9bbaad7f8c67b54b1211d0f235ef9b9aad Mon Sep 17 00:00:00 2001 From: John Tzanakakis Date: Thu, 2 Dec 2021 14:30:13 -0500 Subject: [PATCH] support op overflow consumption for A dump and A plus M dump --- mindspore/ccsrc/debug/debug_services.cc | 93 +++++++++++++++---------- mindspore/ccsrc/debug/debug_services.h | 7 +- tests/st/debugger/test_watchpoints.py | 14 ---- 3 files changed, 60 insertions(+), 54 deletions(-) diff --git a/mindspore/ccsrc/debug/debug_services.cc b/mindspore/ccsrc/debug/debug_services.cc index 5e0d3f982b2..e8d87b8cefa 100644 --- a/mindspore/ccsrc/debug/debug_services.cc +++ b/mindspore/ccsrc/debug/debug_services.cc @@ -1552,6 +1552,17 @@ std::vector> DebugServices::GetNodeTensor(const CNod bool DebugServices::CheckOpOverflow(std::string node_name_to_find, unsigned int device_id, unsigned int root_graph_id, unsigned int iteration) { + // remove kernel_graph_# + std::string op_name_find_with_path = RemoveKernelGraphPrefix(node_name_to_find); + std::replace(op_name_find_with_path.begin(), op_name_find_with_path.end(), '/', '_'); + + // remove path + size_t last_slash = node_name_to_find.rfind("/"); + std::string op_name_find = ""; + if (last_slash != std::string::npos) { + op_name_find = node_name_to_find.substr(last_slash + 1); + } + std::replace(node_name_to_find.begin(), node_name_to_find.end(), '/', '_'); std::vector op_names; std::string overflow_bin_path; @@ -1614,30 +1625,15 @@ bool DebugServices::CheckOpOverflow(std::string node_name_to_find, unsigned int uint64_t stream_id = 0; // detect overflow bin file if (file_name.rfind(overflow_file_prefix, 0) == 0) { - // start of op overflow data in bin file - const uint32_t offset = 321; - (void)infile.seekg(offset, std::ios::beg); - std::vector buffer; - // size of op overflow info section - const size_t buf_size = 256; - buffer.resize(buf_size); - (void)infile.read(buffer.data(), buf_size); - if (infile.gcount() != buf_size) { - MS_LOG(ERROR) << "The file: " << file_path << "may be damaged!"; + if (!GetTaskIdStreamId(file_name, overflow_file_prefix, &task_id, &stream_id)) { continue; } - const uint8_t stream_id_offset = 16; - const uint8_t task_id_offset = 24; - // The stream_id and task_id in the dump file are 8 byte fields for extensibility purpose, but only hold 4 - // byte values currently. - stream_id = BytestoUInt64(std::vector(buffer.begin() + stream_id_offset, buffer.end())); - task_id = BytestoUInt64(std::vector(buffer.begin() + task_id_offset, buffer.end())); MS_LOG(INFO) << "Overflow bin file " << file_name << ", task_id " << task_id << ", stream_id " << stream_id << "."; task_stream_hit.push_back(std::make_pair(task_id, stream_id)); } else { // regular bin file - bool success_parse = GetAttrsFromAsyncFilename(file_name, &node_name, &task_id, &stream_id); + bool success_parse = GetAttrsFromFilename(file_name, &node_name, &task_id, &stream_id); if (success_parse) { task_stream_to_opname[std::make_pair(task_id, stream_id)] = node_name; } @@ -1662,11 +1658,14 @@ bool DebugServices::CheckOpOverflow(std::string node_name_to_find, unsigned int overflow_wp_lock_.unlock(); - // remove prefix "kernel_graph_#_" from node_name_to_find before checking it - std::string op_name_to_find = RemoveKernelGraphPrefix(node_name_to_find); + // determine if overflow wp has been triggered for the op name with path (from bin file) + if (find(op_names.begin(), op_names.end(), op_name_find_with_path) != op_names.end()) { + MS_LOG(INFO) << "Operation overflow watchpoint triggered for " << node_name_to_find; + return true; + } - // determine if overflow wp has been triggered for node_name_to_find - if (find(op_names.begin(), op_names.end(), op_name_to_find) != op_names.end()) { + // determine if overflow wp has been triggered for the op name (from npy file) + if (find(op_names.begin(), op_names.end(), op_name_find) != op_names.end()) { MS_LOG(INFO) << "Operation overflow watchpoint triggered for " << node_name_to_find; return true; } @@ -1678,7 +1677,7 @@ std::string DebugServices::RemoveKernelGraphPrefix(std::string node_name_to_find std::string op_name_to_find = node_name_to_find; const std::string kernel_prefix = "kernel_graph_"; if (node_name_to_find.rfind(kernel_prefix, 0) == 0) { - auto start_of_op_name = node_name_to_find.find("_", kernel_prefix.length()); + auto start_of_op_name = node_name_to_find.find("/", kernel_prefix.length()); if (start_of_op_name != std::string::npos) { op_name_to_find = node_name_to_find.substr(start_of_op_name + 1); } @@ -1686,15 +1685,38 @@ std::string DebugServices::RemoveKernelGraphPrefix(std::string node_name_to_find return op_name_to_find; } -bool DebugServices::GetAttrsFromAsyncFilename(const std::string &file_name, std::string *const node_name, - uint64_t *task_id, uint64_t *stream_id) { - // get the node_name, task_id, and stream_id from async dump filename - // node_type.node_name.task_id.stram_id.timestamp - // WARNING: node_name may have dots in it - size_t fourth_dot = file_name.rfind("."); - size_t third_dot = file_name.rfind(".", fourth_dot - 1); - size_t second_dot = file_name.rfind(".", third_dot - 1); +bool DebugServices::GetTaskIdStreamId(std::string file_name, std::string overflow_file_prefix, uint64_t *task_id, + uint64_t *stream_id) { + size_t task_pos_start = overflow_file_prefix.length(); + size_t task_pos_end = file_name.find(".", task_pos_start); + if (task_pos_end == std::string::npos) { + MS_LOG(ERROR) << "Cannot extract task_id from filename: " << file_name; + return false; + } + + size_t stream_pos_start = task_pos_end + 1; + size_t stream_pos_end = file_name.find(".", stream_pos_start); + if (stream_pos_end == std::string::npos) { + MS_LOG(ERROR) << "Cannot extract stream_id from filename: " << file_name; + return false; + } + + std::string task_id_str = file_name.substr(task_pos_start, task_pos_end - task_pos_start); + std::string stream_id_str = file_name.substr(stream_pos_start, stream_pos_end - stream_pos_start); + *task_id = std::stoull(task_id_str); + *stream_id = std::stoull(stream_id_str); + + return true; +} + +bool DebugServices::GetAttrsFromFilename(const std::string &file_name, std::string *const node_name, uint64_t *task_id, + uint64_t *stream_id) { + // get the node_name, task_id, and stream_id from dump filename + // node_type.node_name.task_id.stream_id.{etcetera} size_t first_dot = file_name.find("."); + size_t second_dot = file_name.find(".", first_dot + 1); + size_t third_dot = file_name.find(".", second_dot + 1); + size_t fourth_dot = file_name.find(".", third_dot + 1); // check if dots were found if (first_dot == std::string::npos || second_dot == std::string::npos || third_dot == std::string::npos || @@ -1702,16 +1724,11 @@ bool DebugServices::GetAttrsFromAsyncFilename(const std::string &file_name, std: return false; } - // check if its not an async bin file - if (file_name.substr(fourth_dot) == ".npy") { - return false; - } - // get node_name if (first_dot < second_dot) { *node_name = file_name.substr(first_dot + 1, second_dot - first_dot - 1); } else { - MS_LOG(ERROR) << "Async filename parse error to get node_name."; + MS_LOG(ERROR) << "filename parse error to get node_name."; return false; } @@ -1728,7 +1745,7 @@ bool DebugServices::GetAttrsFromAsyncFilename(const std::string &file_name, std: return false; } } else { - MS_LOG(ERROR) << "Async filename parse error to get task_id."; + MS_LOG(ERROR) << "filename parse error to get task_id."; return false; } @@ -1745,7 +1762,7 @@ bool DebugServices::GetAttrsFromAsyncFilename(const std::string &file_name, std: return false; } } else { - MS_LOG(ERROR) << "Async filename parse error to get stream_id."; + MS_LOG(ERROR) << "filename parse error to get stream_id."; return false; } diff --git a/mindspore/ccsrc/debug/debug_services.h b/mindspore/ccsrc/debug/debug_services.h index 0a7646f21d0..ca9d1f17b19 100644 --- a/mindspore/ccsrc/debug/debug_services.h +++ b/mindspore/ccsrc/debug/debug_services.h @@ -449,8 +449,11 @@ class DebugServices { std::string RemoveKernelGraphPrefix(std::string node_name_to_find); - bool GetAttrsFromAsyncFilename(const std::string &file_name, std::string *const node_name, uint64_t *task_id, - uint64_t *stream_id); + bool GetTaskIdStreamId(std::string file_name, std::string overflow_file_prefix, uint64_t *task_id, + uint64_t *stream_id); + + bool GetAttrsFromFilename(const std::string &file_name, std::string *const node_name, uint64_t *task_id, + uint64_t *stream_id); std::string RealPath(const std::string &input_path); diff --git a/tests/st/debugger/test_watchpoints.py b/tests/st/debugger/test_watchpoints.py index bc6c4ad4a9d..ec5321c4049 100644 --- a/tests/st/debugger/test_watchpoints.py +++ b/tests/st/debugger/test_watchpoints.py @@ -216,20 +216,6 @@ def test_async_overflow_watchpoints_hit(): run_overflow_watchpoint(True) -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard -@security_off_wrap -def test_async_overflow_watchpoints_not_hit(): - """ - Feature: Offline Debugger CheckWatchpoint - Description: Test check overflow watchpoint hit - Expectation: Overflow watchpoint is not hit - """ - run_overflow_watchpoint(False) - - def compare_expect_actual_result(watchpoint_hits_list, test_index, test_name): """Compare actual result with golden file.""" pwd = os.getcwd()