Fix CheckWatchpoints multi-card issue

This commit is contained in:
Parastoo Ashtari 2021-09-02 13:43:53 -04:00
parent 0fe6e20502
commit a7e977a25e
2 changed files with 29 additions and 11 deletions

View File

@ -158,9 +158,17 @@ void *DebugServices::GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bo
#endif
void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck,
const std::string &tensor_name, const std::string &tensor_name_no_slot,
bool *previous_iter_tensor_needed, std::string *const qualified_tensor_name,
const std::shared_ptr<TensorData> &tensor, bool *previous_iter_tensor_needed,
std::string *const qualified_tensor_name,
std::vector<watchpoint_t> *const watchpoints_to_check) {
if (tensor == nullptr) {
MS_LOG(DEBUG) << "tensor is nullptr.";
return;
}
const auto tensor_name = tensor->GetName();
const auto tensor_name_no_slot = tensor_name.substr(0, tensor_name.find_first_of(':'));
const auto tensor_device_id = tensor->GetDeviceId();
const auto tensor_root_graph_id = tensor->GetRootGraphId();
for (auto w_table_item : watchpoint_table) {
auto wp = std::get<1>(w_table_item);
// check ONLY init conditions on initial suspended state.
@ -178,7 +186,7 @@ void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end,
wp_lock_.unlock();
if (wp_cache_hit) continue;
}
std::string found = wp.FindQualifiedTensorName(tensor_name_no_slot);
std::string found = wp.FindQualifiedTensorName(tensor_name_no_slot, tensor_device_id, tensor_root_graph_id);
if (!found.empty()) {
*qualified_tensor_name = found;
watchpoints_to_check->push_back(w_table_item.second);
@ -238,8 +246,8 @@ void DebugServices::CheckWatchpointsForTensor(
bool previous_iter_tensor_needed = false;
// Add do nothing line in case offline debug is off, prevent unused var warning
(void)previous_iter_tensor_needed;
AddWatchPointsToCheck(init_dbg_suspend, step_end, recheck, tensor_name, tensor_name_no_slot,
&previous_iter_tensor_needed, &qualified_tensor_name, &watchpoints_to_check);
AddWatchPointsToCheck(init_dbg_suspend, step_end, recheck, tensor, &previous_iter_tensor_needed,
&qualified_tensor_name, &watchpoints_to_check);
// no wp set on current tensor
if (watchpoints_to_check.empty()) continue;
uint32_t num_elements = tensor->GetNumElements();

View File

@ -125,16 +125,26 @@ class DebugServices {
std::vector<parameter_t> parameter_list;
size_t location = 0;
std::string FindQualifiedTensorName(const std::string &tensor_name) const {
std::string FindQualifiedTensorName(const std::string &tensor_name, unsigned const int &tensor_device_id,
unsigned const int &tensor_root_graph_id) const {
std::string node_name = tensor_name.substr(0, tensor_name.find_first_of(':'));
int indx = 0;
for (auto check_node : check_node_list) {
std::string w_name = std::get<0>(check_node);
bool w_type = std::get<1>(check_node);
auto found = w_name.find_last_of('/');
if (found != std::string::npos && w_name.substr(found + 1) == tensor_name) return w_name;
if ((w_type && (node_name == w_name || w_name == "*")) || (!w_type && node_name == w_name)) {
return w_name;
bool check_tensor_name = found != std::string::npos && w_name.substr(found + 1) == tensor_name;
bool check_node_name = (w_type && (node_name == w_name || w_name == "*")) || (!w_type && node_name == w_name);
if (check_tensor_name || check_node_name) {
auto device_vec = std::get<1>(check_node_device_list[indx]);
auto root_graph_vec = std::get<1>(check_node_graph_list[indx]);
auto iter1 = std::find(device_vec.begin(), device_vec.end(), tensor_device_id);
auto iter2 = std::find(root_graph_vec.begin(), root_graph_vec.end(), tensor_root_graph_id);
if (iter1 != device_vec.end() && iter2 != root_graph_vec.end()) {
return w_name;
}
}
indx++;
}
return {};
}
@ -214,8 +224,8 @@ class DebugServices {
const bool step_end, const bool recheck, std::vector<unsigned int> *device_id = nullptr,
std::vector<unsigned int> *root_graph_id = nullptr);
// Collects the watchpoints from the watchpoint table that apply to `tensor`.
//
// The tensor's name (with everything from the first ':' removed), device id and root graph id
// are passed to each watchpoint's FindQualifiedTensorName. For every watchpoint that returns a
// non-empty name, that name is written to `*qualified_tensor_name` and the watchpoint is
// appended to `*watchpoints_to_check`. When `tensor` is nullptr the function logs at DEBUG
// level and returns without touching the outputs.
//
// NOTE(review): `init_dbg_suspend`, `step_end` and `recheck` appear to filter which watchpoints
// are considered, and `previous_iter_tensor_needed` appears to be an output flag; confirm
// against the full definition.
void AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck,
                           const std::shared_ptr<TensorData> &tensor, bool *previous_iter_tensor_needed,
                           std::string *qualified_tensor_name, std::vector<watchpoint_t> *watchpoints_to_check);
#ifdef OFFLINE_DBG_MODE