support op overflow consumption for A dump and A plus M dump

This commit is contained in:
John Tzanakakis 2021-12-02 14:30:13 -05:00
parent 9ddd3fb66f
commit 9145dd9bba
3 changed files with 60 additions and 54 deletions

View File

@ -1552,6 +1552,17 @@ std::vector<std::shared_ptr<TensorData>> DebugServices::GetNodeTensor(const CNod
bool DebugServices::CheckOpOverflow(std::string node_name_to_find, unsigned int device_id, unsigned int root_graph_id,
unsigned int iteration) {
// remove kernel_graph_#
std::string op_name_find_with_path = RemoveKernelGraphPrefix(node_name_to_find);
std::replace(op_name_find_with_path.begin(), op_name_find_with_path.end(), '/', '_');
// remove path
size_t last_slash = node_name_to_find.rfind("/");
std::string op_name_find = "";
if (last_slash != std::string::npos) {
op_name_find = node_name_to_find.substr(last_slash + 1);
}
std::replace(node_name_to_find.begin(), node_name_to_find.end(), '/', '_');
std::vector<std::string> op_names;
std::string overflow_bin_path;
@ -1614,30 +1625,15 @@ bool DebugServices::CheckOpOverflow(std::string node_name_to_find, unsigned int
uint64_t stream_id = 0;
// detect overflow bin file
if (file_name.rfind(overflow_file_prefix, 0) == 0) {
// start of op overflow data in bin file
const uint32_t offset = 321;
(void)infile.seekg(offset, std::ios::beg);
std::vector<char> buffer;
// size of op overflow info section
const size_t buf_size = 256;
buffer.resize(buf_size);
(void)infile.read(buffer.data(), buf_size);
if (infile.gcount() != buf_size) {
MS_LOG(ERROR) << "The file: " << file_path << "may be damaged!";
if (!GetTaskIdStreamId(file_name, overflow_file_prefix, &task_id, &stream_id)) {
continue;
}
const uint8_t stream_id_offset = 16;
const uint8_t task_id_offset = 24;
// The stream_id and task_id in the dump file are 8 byte fields for extensibility purpose, but only hold 4
// byte values currently.
stream_id = BytestoUInt64(std::vector<char>(buffer.begin() + stream_id_offset, buffer.end()));
task_id = BytestoUInt64(std::vector<char>(buffer.begin() + task_id_offset, buffer.end()));
MS_LOG(INFO) << "Overflow bin file " << file_name << ", task_id " << task_id << ", stream_id " << stream_id
<< ".";
task_stream_hit.push_back(std::make_pair(task_id, stream_id));
} else {
// regular bin file
bool success_parse = GetAttrsFromAsyncFilename(file_name, &node_name, &task_id, &stream_id);
bool success_parse = GetAttrsFromFilename(file_name, &node_name, &task_id, &stream_id);
if (success_parse) {
task_stream_to_opname[std::make_pair(task_id, stream_id)] = node_name;
}
@ -1662,11 +1658,14 @@ bool DebugServices::CheckOpOverflow(std::string node_name_to_find, unsigned int
overflow_wp_lock_.unlock();
// remove prefix "kernel_graph_#_" from node_name_to_find before checking it
std::string op_name_to_find = RemoveKernelGraphPrefix(node_name_to_find);
// determine if overflow wp has been triggered for the op name with path (from bin file)
if (find(op_names.begin(), op_names.end(), op_name_find_with_path) != op_names.end()) {
MS_LOG(INFO) << "Operation overflow watchpoint triggered for " << node_name_to_find;
return true;
}
// determine if overflow wp has been triggered for node_name_to_find
if (find(op_names.begin(), op_names.end(), op_name_to_find) != op_names.end()) {
// determine if overflow wp has been triggered for the op name (from npy file)
if (find(op_names.begin(), op_names.end(), op_name_find) != op_names.end()) {
MS_LOG(INFO) << "Operation overflow watchpoint triggered for " << node_name_to_find;
return true;
}
@ -1678,7 +1677,7 @@ std::string DebugServices::RemoveKernelGraphPrefix(std::string node_name_to_find
std::string op_name_to_find = node_name_to_find;
const std::string kernel_prefix = "kernel_graph_";
if (node_name_to_find.rfind(kernel_prefix, 0) == 0) {
auto start_of_op_name = node_name_to_find.find("_", kernel_prefix.length());
auto start_of_op_name = node_name_to_find.find("/", kernel_prefix.length());
if (start_of_op_name != std::string::npos) {
op_name_to_find = node_name_to_find.substr(start_of_op_name + 1);
}
@ -1686,15 +1685,38 @@ std::string DebugServices::RemoveKernelGraphPrefix(std::string node_name_to_find
return op_name_to_find;
}
bool DebugServices::GetAttrsFromAsyncFilename(const std::string &file_name, std::string *const node_name,
uint64_t *task_id, uint64_t *stream_id) {
// get the node_name, task_id, and stream_id from async dump filename
// node_type.node_name.task_id.stram_id.timestamp
// WARNING: node_name may have dots in it
size_t fourth_dot = file_name.rfind(".");
size_t third_dot = file_name.rfind(".", fourth_dot - 1);
size_t second_dot = file_name.rfind(".", third_dot - 1);
bool DebugServices::GetTaskIdStreamId(std::string file_name, std::string overflow_file_prefix, uint64_t *task_id,
uint64_t *stream_id) {
size_t task_pos_start = overflow_file_prefix.length();
size_t task_pos_end = file_name.find(".", task_pos_start);
if (task_pos_end == std::string::npos) {
MS_LOG(ERROR) << "Cannot extract task_id from filename: " << file_name;
return false;
}
size_t stream_pos_start = task_pos_end + 1;
size_t stream_pos_end = file_name.find(".", stream_pos_start);
if (stream_pos_end == std::string::npos) {
MS_LOG(ERROR) << "Cannot extract stream_id from filename: " << file_name;
return false;
}
std::string task_id_str = file_name.substr(task_pos_start, task_pos_end - task_pos_start);
std::string stream_id_str = file_name.substr(stream_pos_start, stream_pos_end - stream_pos_start);
*task_id = std::stoull(task_id_str);
*stream_id = std::stoull(stream_id_str);
return true;
}
bool DebugServices::GetAttrsFromFilename(const std::string &file_name, std::string *const node_name, uint64_t *task_id,
uint64_t *stream_id) {
// get the node_name, task_id, and stream_id from dump filename
// node_type.node_name.task_id.stream_id.{etcetera}
size_t first_dot = file_name.find(".");
size_t second_dot = file_name.find(".", first_dot + 1);
size_t third_dot = file_name.find(".", second_dot + 1);
size_t fourth_dot = file_name.find(".", third_dot + 1);
// check if dots were found
if (first_dot == std::string::npos || second_dot == std::string::npos || third_dot == std::string::npos ||
@ -1702,16 +1724,11 @@ bool DebugServices::GetAttrsFromAsyncFilename(const std::string &file_name, std:
return false;
}
// check if its not an async bin file
if (file_name.substr(fourth_dot) == ".npy") {
return false;
}
// get node_name
if (first_dot < second_dot) {
*node_name = file_name.substr(first_dot + 1, second_dot - first_dot - 1);
} else {
MS_LOG(ERROR) << "Async filename parse error to get node_name.";
MS_LOG(ERROR) << "filename parse error to get node_name.";
return false;
}
@ -1728,7 +1745,7 @@ bool DebugServices::GetAttrsFromAsyncFilename(const std::string &file_name, std:
return false;
}
} else {
MS_LOG(ERROR) << "Async filename parse error to get task_id.";
MS_LOG(ERROR) << "filename parse error to get task_id.";
return false;
}
@ -1745,7 +1762,7 @@ bool DebugServices::GetAttrsFromAsyncFilename(const std::string &file_name, std:
return false;
}
} else {
MS_LOG(ERROR) << "Async filename parse error to get stream_id.";
MS_LOG(ERROR) << "filename parse error to get stream_id.";
return false;
}

View File

@ -449,8 +449,11 @@ class DebugServices {
std::string RemoveKernelGraphPrefix(std::string node_name_to_find);
bool GetAttrsFromAsyncFilename(const std::string &file_name, std::string *const node_name, uint64_t *task_id,
uint64_t *stream_id);
bool GetTaskIdStreamId(std::string file_name, std::string overflow_file_prefix, uint64_t *task_id,
uint64_t *stream_id);
bool GetAttrsFromFilename(const std::string &file_name, std::string *const node_name, uint64_t *task_id,
uint64_t *stream_id);
std::string RealPath(const std::string &input_path);

View File

@ -216,20 +216,6 @@ def test_async_overflow_watchpoints_hit():
run_overflow_watchpoint(True)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_async_overflow_watchpoints_not_hit():
"""
Feature: Offline Debugger CheckWatchpoint
Description: Test check overflow watchpoint hit
Expectation: Overflow watchpoint is not hit
"""
run_overflow_watchpoint(False)
def compare_expect_actual_result(watchpoint_hits_list, test_index, test_name):
"""Compare actual result with golden file."""
pwd = os.getcwd()