From 1019bfc70f6f30e106163c4f757dca5c44f1decf Mon Sep 17 00:00:00 2001 From: Parastoo Ashtari Date: Thu, 2 Sep 2021 13:43:53 -0400 Subject: [PATCH] fix checkwatchpoints multi card issue --- mindspore/ccsrc/debug/debug_services.cc | 18 +++++++++++++----- mindspore/ccsrc/debug/debug_services.h | 22 ++++++++++++++++------ 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/debug/debug_services.cc b/mindspore/ccsrc/debug/debug_services.cc index 9d861c9233b..3d1c54209a5 100644 --- a/mindspore/ccsrc/debug/debug_services.cc +++ b/mindspore/ccsrc/debug/debug_services.cc @@ -201,9 +201,17 @@ void *DebugServices::GetPrevTensor(const std::shared_ptr &tensor, bo #endif void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck, - const std::string &tensor_name, const std::string &tensor_name_no_slot, - bool *previous_iter_tensor_needed, std::string *const qualified_tensor_name, + const std::shared_ptr &tensor, bool *previous_iter_tensor_needed, + std::string *const qualified_tensor_name, std::vector *const watchpoints_to_check) { + if (tensor == nullptr) { + MS_LOG(DEBUG) << "tensor is nullptr."; + return; + } + const auto tensor_name = tensor->GetName(); + const auto tensor_name_no_slot = tensor_name.substr(0, tensor_name.find_first_of(':')); + const auto tensor_device_id = tensor->GetDeviceId(); + const auto tensor_root_graph_id = tensor->GetRootGraphId(); for (auto w_table_item : watchpoint_table_) { auto wp = std::get<1>(w_table_item); // check ONLY init conditions on initial suspended state. @@ -229,7 +237,7 @@ void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, continue; } } - std::string found = wp.FindQualifiedTensorName(tensor_name_no_slot); + std::string found = wp.FindQualifiedTensorName(tensor_name_no_slot, tensor_device_id, tensor_root_graph_id); if (!found.empty()) { *qualified_tensor_name = found; watchpoints_to_check->push_back(w_table_item.second); @@ -296,8 +304,8 @@ void DebugServices::CheckWatchpointsForTensor( bool previous_iter_tensor_needed = false; // Add do nothing line in case offline debug is off, prevent unused var warning (void)previous_iter_tensor_needed; - AddWatchPointsToCheck(init_dbg_suspend, step_end, recheck, tensor_name, tensor_name_no_slot, - &previous_iter_tensor_needed, &qualified_tensor_name, &watchpoints_to_check); + AddWatchPointsToCheck(init_dbg_suspend, step_end, recheck, tensor, &previous_iter_tensor_needed, + &qualified_tensor_name, &watchpoints_to_check); // no wp set on current tensor if (watchpoints_to_check.empty()) { continue; diff --git a/mindspore/ccsrc/debug/debug_services.h b/mindspore/ccsrc/debug/debug_services.h index c328b88b8e4..98860afb1b5 100644 --- a/mindspore/ccsrc/debug/debug_services.h +++ b/mindspore/ccsrc/debug/debug_services.h @@ -126,16 +126,26 @@ class DebugServices { std::vector parameter_list; size_t location = 0; - std::string FindQualifiedTensorName(const std::string &tensor_name) const { + std::string FindQualifiedTensorName(const std::string &tensor_name, unsigned const int &tensor_device_id, + unsigned const int &tensor_root_graph_id) const { std::string node_name = tensor_name.substr(0, tensor_name.find_first_of(':')); + int indx = 0; for (auto check_node : check_node_list) { std::string w_name = std::get<0>(check_node); bool w_type = std::get<1>(check_node); auto found = w_name.find_last_of('/'); - if (found != std::string::npos && w_name.substr(found + 1) == tensor_name) return w_name; - if ((w_type && (node_name == w_name || w_name == "*")) || (!w_type && node_name == w_name)) { - return w_name; + bool check_tensor_name = found != std::string::npos && w_name.substr(found + 1) == tensor_name; + bool check_node_name = (w_type && (node_name == w_name || w_name == "*")) || (!w_type && node_name == w_name); + if (check_tensor_name || check_node_name) { + auto device_vec = std::get<1>(check_node_device_list[indx]); + auto root_graph_vec = std::get<1>(check_node_graph_list[indx]); + auto iter1 = std::find(device_vec.begin(), device_vec.end(), tensor_device_id); + auto iter2 = std::find(root_graph_vec.begin(), root_graph_vec.end(), tensor_root_graph_id); + if (iter1 != device_vec.end() && iter2 != root_graph_vec.end()) { + return w_name; + } } + indx++; } return {}; } @@ -276,8 +286,8 @@ class DebugServices { partitioned_id *chunk_root_graph_id, std::vector *device_id, std::vector *root_graph_id); - void AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck, const std::string &tensor_name, - const std::string &tensor_name_no_slot, bool *previous_iter_tensor_needed, + void AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck, + const std::shared_ptr &tensor, bool *previous_iter_tensor_needed, std::string *qualified_tensor_name, std::vector *watchpoints_to_check); void SetCheckWatchpointsResult(const int chunk_id, partitioned_names *chunk_names, partitioned_names *chunk_slots,