forked from mindspore-Ecosystem/mindspore
fix checkwatchpoints multi card issue
This commit is contained in:
parent
0fe6e20502
commit
a7e977a25e
|
@ -158,9 +158,17 @@ void *DebugServices::GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bo
|
|||
#endif
|
||||
|
||||
void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck,
|
||||
const std::string &tensor_name, const std::string &tensor_name_no_slot,
|
||||
bool *previous_iter_tensor_needed, std::string *const qualified_tensor_name,
|
||||
const std::shared_ptr<TensorData> &tensor, bool *previous_iter_tensor_needed,
|
||||
std::string *const qualified_tensor_name,
|
||||
std::vector<watchpoint_t> *const watchpoints_to_check) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(DEBUG) << "tensor is nullptr.";
|
||||
return;
|
||||
}
|
||||
const auto tensor_name = tensor->GetName();
|
||||
const auto tensor_name_no_slot = tensor_name.substr(0, tensor_name.find_first_of(':'));
|
||||
const auto tensor_device_id = tensor->GetDeviceId();
|
||||
const auto tensor_root_graph_id = tensor->GetRootGraphId();
|
||||
for (auto w_table_item : watchpoint_table) {
|
||||
auto wp = std::get<1>(w_table_item);
|
||||
// check ONLY init conditions on initial suspended state.
|
||||
|
@ -178,7 +186,7 @@ void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end,
|
|||
wp_lock_.unlock();
|
||||
if (wp_cache_hit) continue;
|
||||
}
|
||||
std::string found = wp.FindQualifiedTensorName(tensor_name_no_slot);
|
||||
std::string found = wp.FindQualifiedTensorName(tensor_name_no_slot, tensor_device_id, tensor_root_graph_id);
|
||||
if (!found.empty()) {
|
||||
*qualified_tensor_name = found;
|
||||
watchpoints_to_check->push_back(w_table_item.second);
|
||||
|
@ -238,8 +246,8 @@ void DebugServices::CheckWatchpointsForTensor(
|
|||
bool previous_iter_tensor_needed = false;
|
||||
// Reference the variable in a no-op statement so that builds without offline debug do not emit an unused-variable warning.
|
||||
(void)previous_iter_tensor_needed;
|
||||
AddWatchPointsToCheck(init_dbg_suspend, step_end, recheck, tensor_name, tensor_name_no_slot,
|
||||
&previous_iter_tensor_needed, &qualified_tensor_name, &watchpoints_to_check);
|
||||
AddWatchPointsToCheck(init_dbg_suspend, step_end, recheck, tensor, &previous_iter_tensor_needed,
|
||||
&qualified_tensor_name, &watchpoints_to_check);
|
||||
// no wp set on current tensor
|
||||
if (watchpoints_to_check.empty()) continue;
|
||||
uint32_t num_elements = tensor->GetNumElements();
|
||||
|
|
|
@ -125,16 +125,26 @@ class DebugServices {
|
|||
std::vector<parameter_t> parameter_list;
|
||||
size_t location = 0;
|
||||
|
||||
std::string FindQualifiedTensorName(const std::string &tensor_name) const {
|
||||
std::string FindQualifiedTensorName(const std::string &tensor_name, unsigned const int &tensor_device_id,
|
||||
unsigned const int &tensor_root_graph_id) const {
|
||||
std::string node_name = tensor_name.substr(0, tensor_name.find_first_of(':'));
|
||||
int indx = 0;
|
||||
for (auto check_node : check_node_list) {
|
||||
std::string w_name = std::get<0>(check_node);
|
||||
bool w_type = std::get<1>(check_node);
|
||||
auto found = w_name.find_last_of('/');
|
||||
if (found != std::string::npos && w_name.substr(found + 1) == tensor_name) return w_name;
|
||||
if ((w_type && (node_name == w_name || w_name == "*")) || (!w_type && node_name == w_name)) {
|
||||
return w_name;
|
||||
bool check_tensor_name = found != std::string::npos && w_name.substr(found + 1) == tensor_name;
|
||||
bool check_node_name = (w_type && (node_name == w_name || w_name == "*")) || (!w_type && node_name == w_name);
|
||||
if (check_tensor_name || check_node_name) {
|
||||
auto device_vec = std::get<1>(check_node_device_list[indx]);
|
||||
auto root_graph_vec = std::get<1>(check_node_graph_list[indx]);
|
||||
auto iter1 = std::find(device_vec.begin(), device_vec.end(), tensor_device_id);
|
||||
auto iter2 = std::find(root_graph_vec.begin(), root_graph_vec.end(), tensor_root_graph_id);
|
||||
if (iter1 != device_vec.end() && iter2 != root_graph_vec.end()) {
|
||||
return w_name;
|
||||
}
|
||||
}
|
||||
indx++;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
@ -214,8 +224,8 @@ class DebugServices {
|
|||
const bool step_end, const bool recheck, std::vector<unsigned int> *device_id = nullptr,
|
||||
std::vector<unsigned int> *root_graph_id = nullptr);
|
||||
|
||||
void AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck, const std::string &tensor_name,
|
||||
const std::string &tensor_name_no_slot, bool *previous_iter_tensor_needed,
|
||||
void AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck,
|
||||
const std::shared_ptr<TensorData> &tensor, bool *previous_iter_tensor_needed,
|
||||
std::string *qualified_tensor_name, std::vector<watchpoint_t> *watchpoints_to_check);
|
||||
|
||||
#ifdef OFFLINE_DBG_MODE
|
||||
|
|
Loading…
Reference in New Issue