diff --git a/mindspore/ccsrc/debug/CMakeLists.txt b/mindspore/ccsrc/debug/CMakeLists.txt index b57f3288bbc..64a5e23b18d 100644 --- a/mindspore/ccsrc/debug/CMakeLists.txt +++ b/mindspore/ccsrc/debug/CMakeLists.txt @@ -20,6 +20,7 @@ set(_OFFLINE_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/debugger/offline_debug/dbg_services.cc" "${CMAKE_SOURCE_DIR}/mindspore/core/utils/log_adapter.cc" "${CMAKE_CURRENT_SOURCE_DIR}/debugger/offline_debug/mi_pybind_register.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/utils.cc" ) if(ENABLE_DUMP_IR) @@ -43,6 +44,7 @@ if(ENABLE_DEBUGGER) "${CMAKE_CURRENT_SOURCE_DIR}/debug_services.cc" "${CMAKE_CURRENT_SOURCE_DIR}/debugger/debugger_utils.cc" "${CMAKE_CURRENT_SOURCE_DIR}/data_dump/tensor_stat_dump.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/utils.cc" ) endif() @@ -52,6 +54,7 @@ if(NOT ENABLE_SECURITY) "${CMAKE_CURRENT_SOURCE_DIR}/data_dump/dump_json_parser.cc" "${CMAKE_CURRENT_SOURCE_DIR}/data_dump/dump_utils.cc" "${CMAKE_CURRENT_SOURCE_DIR}/data_dump/npy_header.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/utils.cc" ) if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows") list(APPEND _DEBUG_SRC_LIST diff --git a/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc b/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc index 811631c338f..cee909ed9e3 100644 --- a/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc +++ b/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc @@ -17,6 +17,7 @@ #include #include "utils/log_adapter.h" #include "debug/common.h" +#include "debug/utils.h" #include "utils/ms_context.h" #include "utils/convert_utils_base.h" #include "backend/common/session/anf_runtime_algorithm.h" @@ -465,7 +466,13 @@ bool IsIterInRange(uint32_t iteration, const std::string &range) { std::size_t range_idx = range.find(dash); // no dash in range, compare the value directly if (range_idx == std::string::npos) { - return iteration == std::stoul(range); + size_t range_d = 0; + if (!CheckStoul(&range_d, range)) { + MS_LOG(INFO) << "Failed to convert the single step range: " << range + << " into an integer, so the iteration: " << iteration << " is regarded as not in dump range."; + return false; + } + return iteration == range_d; } // make sure there is only one dash in range if (range.find(dash, range_idx + 1) != std::string::npos) { @@ -476,8 +483,18 @@ bool IsIterInRange(uint32_t iteration, const std::string &range) { if (low_range_str.empty() || high_range_str.empty()) { return false; } - uint32_t low_range = static_cast(std::stoul(low_range_str)); - uint32_t high_range = static_cast(std::stoul(high_range_str)); + size_t low_range = 0; + if (!CheckStoul(&low_range, low_range_str)) { + MS_LOG(INFO) << "Failed to convert the low_range_str: " << low_range_str + << " into an integer, so the iteration: " << iteration << " is regarded as not in dump range."; + return false; + } + size_t high_range = 0; + if (!CheckStoul(&high_range, high_range_str)) { + MS_LOG(INFO) << "Failed to convert the high_range_str: " << high_range_str + << " into an integer, so the iteration: " << iteration << " is regarded as not in dump range."; + return false; + } return (low_range <= iteration) && (iteration <= high_range); } diff --git a/mindspore/ccsrc/debug/debug_services.cc b/mindspore/ccsrc/debug/debug_services.cc index fb29367ac44..44a32e21c01 100644 --- a/mindspore/ccsrc/debug/debug_services.cc +++ b/mindspore/ccsrc/debug/debug_services.cc @@ -34,13 +34,13 @@ #include "debug/anf_ir_utils.h" #include "backend/common/session/anf_runtime_algorithm.h" #endif +#include "debug/utils.h" #include "nlohmann/json.hpp" #include "debug/debugger/tensor_summary.h" #include "utils/file_utils.h" #include "climits" -#ifdef ONLINE_DBG_MODE + namespace mindspore { -#endif static constexpr const char *constant_prefix = "Default--data-"; static constexpr const char *kNpyExt = ".npy"; @@ -90,8 +90,8 @@ DebugServices &DebugServices::operator=(const DebugServices &other) { * watchpoint_table. */ void DebugServices::AddWatchpoint( - unsigned int id, int watch_condition, float parameter, - const std::vector> &check_node_list, const std::vector ¶meter_list, + int id, int watch_condition, float parameter, const std::vector> &check_node_list, + const std::vector ¶meter_list, const std::vector>> *check_node_device_list, const std::vector>> *check_node_graph_list) { std::lock_guard lg(lock_); @@ -732,7 +732,13 @@ void DebugServices::SortWatchpointsInfo( std::vector().swap((*chunk_error_codes)[i]); std::vector().swap((*chunk_device_id)[i]); std::vector().swap((*chunk_root_graph_id)[i]); - (*tensor_list_byte_size) += (*chunk_tensor_byte_size)[i]; + if ((*tensor_list_byte_size) > ULONG_LONG_MAX - (*chunk_tensor_byte_size)[i]) { + MS_LOG(WARNING) << (*tensor_list_byte_size) << " + " << (*chunk_tensor_byte_size)[i] + << " would lead to integer overflow!"; + (*tensor_list_byte_size) = ULONG_LONG_MAX; + } else { + (*tensor_list_byte_size) += (*chunk_tensor_byte_size)[i]; + } } } @@ -800,9 +806,20 @@ void DebugServices::ReadTensorFromNpy(const std::string &tensor_name, const std: std::stringstream check_shape(shape_str); MS_LOG(INFO) << "Shape of " << file_name << " is: [" << shape_str << "]"; while (getline(check_shape, intermediate, ',')) { - shape->push_back(std::stoi(intermediate)); + int64_t shape_d = 0; + if (!CheckStoi(&shape_d, intermediate)) { + MS_LOG(INFO) << "Failed to get the shape from file: " << file_name << ", error in convert the string " + << intermediate << " into an integer."; + return; + } + shape->push_back(shape_d); + } + std::size_t word_size = 0; + if (!CheckStoul(&word_size, std::string(1, (*tensor_type)[1]))) { + MS_LOG(INFO) << "Failed to get the word_size from file: " << file_name << ", error in convert the string " + << (*tensor_type)[1] << " into an integer."; + return; } - std::size_t word_size = std::stoul(std::string(1, (*tensor_type)[1])); std::size_t data_len = std::accumulate(shape->begin(), shape->end(), 1, std::multiplies()); std::size_t data_size = data_len * word_size; if (!data_size) { @@ -880,25 +897,26 @@ void DebugServices::ProcessConvertToHostFormat(const std::vector &f std::string real_dump_iter_dir = RealPath(dump_key); DIR *d_handle = opendir(real_dump_iter_dir.c_str()); if (d_handle == nullptr) { - MS_LOG(INFO) << "Directory does not exist in ConvertToHostFormat."; + MS_LOG(INFO) << "Directory " << real_dump_iter_dir << " does not exist in ConvertToHostFormat."; return; } struct dirent *dir = nullptr; while ((dir = readdir(d_handle)) != nullptr) { std::string name = real_dump_iter_dir + std::string("/") + std::string(dir->d_name); - if (IsRegFile(name)) { - std::string candidate = dir->d_name; - for (const std::string &file_to_find : files_after_convert_in_dir) { - std::string file_n = file_to_find; - auto last_slash_pos = file_to_find.find_last_of("\\/"); - if (last_slash_pos != std::string::npos) { - file_n = file_to_find.substr(last_slash_pos + 1); - } - if (candidate.find(file_n + ".") != std::string::npos && candidate.rfind(kNpyExt) != std::string::npos) { - // we found a converted file for this op - std::string found_file = dump_key + "/" + candidate; - (void)result_list->insert(found_file); - } + if (!IsRegFile(name)) { + continue; + } + std::string candidate = dir->d_name; + for (const std::string &file_to_find : files_after_convert_in_dir) { + std::string file_n = file_to_find; + auto last_slash_pos = file_to_find.find_last_of("\\/"); + if (last_slash_pos != std::string::npos) { + file_n = file_to_find.substr(last_slash_pos + 1); + } + if (candidate.find(file_n + ".") != std::string::npos && candidate.rfind(kNpyExt) != std::string::npos) { + // we found a converted file for this op + std::string found_file = dump_key + "/" + candidate; + (void)result_list->insert(found_file); } } } @@ -1046,8 +1064,14 @@ void DebugServices::GetTensorDataInfoAsync(const std::vectord_name; - std::string file_path = specific_dump_dir + std::string("/") + file_name; - if (IsRegFile(file_path)) { - for (auto &node : proto_to_dump) { - std::string dump_name = std::get<1>(node); - std::string stripped_file_name = GetStrippedFilename(file_name); - if (stripped_file_name.empty() || stripped_file_name.length() <= dump_name.length()) { + return; + } + struct dirent *dir = nullptr; + while ((dir = readdir(d)) != nullptr) { + std::string file_name = dir->d_name; + std::string file_path = specific_dump_dir + std::string("/") + file_name; + if (IsRegFile(file_path)) { + for (auto &node : proto_to_dump) { + std::string dump_name = std::get<1>(node); + std::string stripped_file_name = GetStrippedFilename(file_name); + if (stripped_file_name.empty() || stripped_file_name.length() <= dump_name.length()) { + continue; + } + std::size_t found = stripped_file_name.rfind(dump_name + ".", 0); + if (found == 0) { + size_t slot = 0; + if (!CheckStoul(&slot, stripped_file_name.substr(dump_name.length() + 1))) { + MS_LOG(INFO) << "Failed to get the slot from file_name: " << file_name << ", error in convert the string " + << stripped_file_name.substr(dump_name.length() + 1) << " into an integer."; continue; } - std::size_t found = stripped_file_name.rfind(dump_name + ".", 0); - if (found == 0) { - size_t slot = std::stoul(stripped_file_name.substr(dump_name.length() + 1)); - std::vector shape; - std::string orig_name = std::get<0>(node); - std::string output_str = dump_name.substr(dump_name.rfind(".") + 1); - bool output_flag = (output_str == "output"); + std::vector shape; + std::string orig_name = std::get<0>(node); + std::string output_str = dump_name.substr(dump_name.rfind(".") + 1); + bool output_flag = (output_str == "output"); - AddToTensorData(orig_name, "", slot, iteration, device_id, root_graph_id, output_flag, 0, "", shape, - nullptr, tensor_list); - break; - } + AddToTensorData(orig_name, "", slot, iteration, device_id, root_graph_id, output_flag, 0, "", shape, nullptr, + tensor_list); + break; } } } - (void)closedir(d); } + (void)closedir(d); } std::string DebugServices::IterationString(unsigned int iteration) { @@ -2018,8 +2047,16 @@ bool DebugServices::GetTaskIdStreamId(std::string file_name, std::string overflo std::string task_id_str = file_name.substr(task_pos_start, task_pos_end - task_pos_start); std::string stream_id_str = file_name.substr(stream_pos_start, stream_pos_end - stream_pos_start); - *task_id = std::stoull(task_id_str); - *stream_id = std::stoull(stream_id_str); + if (!CheckStoull(task_id, task_id_str)) { + MS_LOG(INFO) << "Failed to get the task_id from file_name: " << file_name << ", error in convert the string " + << task_id_str << " into an integer."; + return false; + } + if (!CheckStoull(stream_id, stream_id_str)) { + MS_LOG(INFO) << "Failed to get the stream_id from file_name: " << file_name << ", error in convert the string " + << stream_id_str << " into an integer."; + return false; + } return true; } @@ -2063,13 +2100,9 @@ bool DebugServices::GetAttrsFromFilename(const std::string &file_name, std::stri // get task id if (second_dot < third_dot) { std::string extracted_task_id = file_name.substr(second_dot + 1, third_dot - second_dot - 1); - try { - *task_id = std::stoull(extracted_task_id); - } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "stoull failed on extracted_task_id to get task_id, invalid argument."; - return false; - } catch (std::out_of_range &e) { - MS_LOG(ERROR) << "stoull failed on extracted_task_id to get task_id, out of range."; + if (!CheckStoull(task_id, extracted_task_id)) { + MS_LOG(INFO) << "Failed to get the task_id from file_name: " << file_name << ", error in convert the string " + << extracted_task_id << " into an integer."; return false; } } else { @@ -2079,13 +2112,9 @@ bool DebugServices::GetAttrsFromFilename(const std::string &file_name, std::stri // get stream id if (third_dot < fourth_dot) { std::string extracted_stream_id = file_name.substr(third_dot + 1, fourth_dot - third_dot - 1); - try { - *stream_id = std::stoull(extracted_stream_id); - } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "stoull failed on extracted_stream_id to get stream_id, invalid argument."; - return false; - } catch (std::out_of_range &e) { - MS_LOG(ERROR) << "stoull failed on extracted_stream_id to get stream_id, out of range."; + if (!CheckStoull(stream_id, extracted_stream_id)) { + MS_LOG(INFO) << "Failed to get the stream_id from file_name: " << file_name << ", error in convert the string " + << extracted_stream_id << " into an integer."; return false; } } else { @@ -2168,6 +2197,4 @@ bool DebugServices::GetSyncMode() { return is_sync_mode_; } void DebugServices::SetMemLimit(uint64_t max_mem_size) { tensor_loader_->SetMemTotal(max_mem_size); } -#ifdef ONLINE_DBG_MODE } // namespace mindspore -#endif diff --git a/mindspore/ccsrc/debug/debug_services.h b/mindspore/ccsrc/debug/debug_services.h index 2d21a2471c5..a35018c00c2 100644 --- a/mindspore/ccsrc/debug/debug_services.h +++ b/mindspore/ccsrc/debug/debug_services.h @@ -40,9 +40,7 @@ #include "debug/tensor_load.h" #include "debug/tensor_data.h" -#ifdef ONLINE_DBG_MODE namespace mindspore { -#endif class DebugServices { public: DebugServices(); @@ -242,8 +240,8 @@ class DebugServices { static TensorStat GetTensorStatistics(const std::shared_ptr &tensor); void AddWatchpoint( - unsigned int id, int watch_condition, float parameter, - const std::vector> &check_node_list, const std::vector ¶meter_list, + int id, int watch_condition, float parameter, const std::vector> &check_node_list, + const std::vector ¶meter_list, const std::vector>> *check_node_device_list = nullptr, const std::vector>> *check_node_graph_list = nullptr); @@ -496,8 +494,6 @@ class DebugServices { std::shared_ptr tensor_loader_; }; -#ifdef ONLINE_DBG_MODE } // namespace mindspore -#endif #endif // MINDSPORE_CCSRC_DEBUG_DEBUG_SERVICES_H_ diff --git a/mindspore/ccsrc/debug/debugger/debug_grpc.proto b/mindspore/ccsrc/debug/debugger/debug_grpc.proto index 499967ffd6c..2e933325594 100644 --- a/mindspore/ccsrc/debug/debugger/debug_grpc.proto +++ b/mindspore/ccsrc/debug/debugger/debug_grpc.proto @@ -161,13 +161,13 @@ message Statistics { float max_value = 2; float min_value = 3; float avg_value = 4; - int32 count = 5; - int32 neg_zero_count = 6; - int32 pos_zero_count = 7; - int32 nan_count = 8; - int32 neg_inf_count = 9; - int32 pos_inf_count = 10; - int32 zero_count = 11; + uint64 count = 5; + uint64 neg_zero_count = 6; + uint64 pos_zero_count = 7; + uint64 nan_count = 8; + uint64 neg_inf_count = 9; + uint64 pos_inf_count = 10; + uint64 zero_count = 11; } message TensorBase{ diff --git a/mindspore/ccsrc/debug/debugger/offline_debug/dbg_services.cc b/mindspore/ccsrc/debug/debugger/offline_debug/dbg_services.cc index 76bdff640df..3eb22eb9b28 100644 --- a/mindspore/ccsrc/debug/debugger/offline_debug/dbg_services.cc +++ b/mindspore/ccsrc/debug/debugger/offline_debug/dbg_services.cc @@ -17,6 +17,7 @@ #include #include +#include "debug/utils.h" namespace mindspore { DbgServices::DbgServices() { debug_services_ = std::make_shared(); } @@ -77,7 +78,7 @@ int32_t DbgServices::Initialize(const std::string net_name, const std::string du } int32_t DbgServices::AddWatchpoint( - unsigned int id, int watch_condition, + int id, int watch_condition, std::map>>> check_nodes, std::vector parameter_list) { MS_EXCEPTION_IF_NULL(debug_services_); @@ -94,9 +95,14 @@ int32_t DbgServices::AddWatchpoint( std::vector rank_id_str = std::get>(attr_map["rank_id"]); std::vector rank_id; - (void)std::transform( - rank_id_str.begin(), rank_id_str.end(), std::back_inserter(rank_id), - [](const std::string &id_str) -> std::uint32_t { return static_cast(std::stoul(id_str)); }); + (void)std::transform(rank_id_str.begin(), rank_id_str.end(), std::back_inserter(rank_id), + [](const std::string &id_str) -> std::uint32_t { + size_t id_inter = 0; + if (!CheckStoul(&id_inter, id_str)) { + MS_LOG(EXCEPTION) << "Failed to extract rand_id!"; + } + return static_cast(id_inter); + }); MS_LOG(DEBUG) << "cpp DbgServices AddWatchpoint rank_id: "; for (auto const &i : rank_id) { MS_LOG(DEBUG) << i << " "; @@ -104,9 +110,14 @@ int32_t DbgServices::AddWatchpoint( std::vector root_graph_id_str = std::get>(attr_map["root_graph_id"]); std::vector root_graph_id; - (void)std::transform( - root_graph_id_str.begin(), root_graph_id_str.end(), std::back_inserter(root_graph_id), - [](const std::string &graph_str) -> std::uint32_t { return static_cast(std::stoul(graph_str)); }); + (void)std::transform(root_graph_id_str.begin(), root_graph_id_str.end(), std::back_inserter(root_graph_id), + [](const std::string &graph_str) -> std::uint32_t { + size_t graph_inter = 0; + if (!CheckStoul(&graph_inter, graph_str)) { + MS_LOG(EXCEPTION) << "Failed to extract graph_id!"; + } + return static_cast(graph_inter); + }); MS_LOG(DEBUG) << "cpp DbgServices AddWatchpoint root_graph_id: "; for (auto const &j : root_graph_id) { MS_LOG(DEBUG) << j << " "; @@ -139,7 +150,11 @@ int32_t DbgServices::AddWatchpoint( std::vector rank_id; (void)std::transform(rank_id_str.begin(), rank_id_str.end(), std::back_inserter(rank_id), [](std::string &id_str) -> std::uint32_t { - return static_cast(std::stoul(id_str)); + size_t id_inter = 0; + if (!CheckStoul(&id_inter, id_str)) { + MS_LOG(EXCEPTION) << "Failed to extract rand_id!"; + } + return static_cast(id_inter); }); return std::make_tuple(node.first, rank_id); }); @@ -150,9 +165,14 @@ int32_t DbgServices::AddWatchpoint( auto attr_map = node.second; std::vector root_graph_id_str = std::get>(attr_map["root_graph_id"]); std::vector root_graph_id; - (void)std::transform( - root_graph_id_str.begin(), root_graph_id_str.end(), std::back_inserter(root_graph_id), - [](std::string &graph_str) -> std::uint32_t { return static_cast(std::stoul(graph_str)); }); + (void)std::transform(root_graph_id_str.begin(), root_graph_id_str.end(), std::back_inserter(root_graph_id), + [](std::string &graph_str) -> std::uint32_t { + size_t graph_inter = 0; + if (!CheckStoul(&graph_inter, graph_str)) { + MS_LOG(EXCEPTION) << "Failed to extract graph_id!"; + } + return static_cast(graph_inter); + }); return std::make_tuple(node.first, root_graph_id); }); @@ -204,8 +224,12 @@ std::vector DbgServices::CheckWatchpoints(unsigned int iterati parameter_t api_parameter(p.name, p.disabled, p.value, p.hit, p.actual_value); api_parameter_vector.push_back(api_parameter); } - watchpoint_hit_t hit(name[i], std::stoi(slot[i]), condition[i], watchpoint_id[i], api_parameter_vector, - error_codes[i], rank_id[i], root_graph_id[i]); + size_t slot_inter = 0; + if (!CheckStoul(&slot_inter, slot[i])) { + MS_LOG(EXCEPTION) << "Failed to extract slot_id!"; + } + watchpoint_hit_t hit(name[i], static_cast(slot_inter), condition[i], watchpoint_id[i], + api_parameter_vector, error_codes[i], rank_id[i], root_graph_id[i]); MS_LOG(DEBUG) << "cpp DbgServices watchpoint_hit_t name " << hit.name; MS_LOG(DEBUG) << "cpp DbgServices watchpoint_hit_t slot " << hit.slot; diff --git a/mindspore/ccsrc/debug/debugger/offline_debug/dbg_services.h b/mindspore/ccsrc/debug/debugger/offline_debug/dbg_services.h index cc4eb08c005..6ab1d73c81e 100644 --- a/mindspore/ccsrc/debug/debugger/offline_debug/dbg_services.h +++ b/mindspore/ccsrc/debug/debugger/offline_debug/dbg_services.h @@ -196,7 +196,7 @@ class DbgServices { uint64_t max_mem_usage); int32_t AddWatchpoint( - unsigned int id, int watch_condition, + int id, int watch_condition, std::map>>> check_nodes, std::vector parameter_list); diff --git a/mindspore/ccsrc/debug/debugger/tensor_summary.cc b/mindspore/ccsrc/debug/debugger/tensor_summary.cc index d46c272054f..7169ea1f454 100644 --- a/mindspore/ccsrc/debug/debugger/tensor_summary.cc +++ b/mindspore/ccsrc/debug/debugger/tensor_summary.cc @@ -28,9 +28,7 @@ #include "base/float16.h" #endif -#ifdef ONLINE_DBG_MODE namespace mindspore { -#endif using CONDITION_TYPE = DebugServices::CONDITION_TYPE; RangeCountCalculator::RangeCountCalculator() @@ -437,6 +435,4 @@ template class TensorSummary; template class TensorSummary; template class TensorSummary; template class TensorSummary; -#ifdef ONLINE_DBG_MODE } // namespace mindspore -#endif diff --git a/mindspore/ccsrc/debug/debugger/tensor_summary.h b/mindspore/ccsrc/debug/debugger/tensor_summary.h index 72f3b4de726..121af9172b1 100644 --- a/mindspore/ccsrc/debug/debugger/tensor_summary.h +++ b/mindspore/ccsrc/debug/debugger/tensor_summary.h @@ -24,9 +24,7 @@ #include "utils/hash_map.h" #include "debug/debug_services.h" -#ifdef ONLINE_DBG_MODE namespace mindspore { -#endif class RangeCountCalculator { public: RangeCountCalculator(); @@ -164,7 +162,5 @@ class TensorSummary : public ITensorSummary { void TensorStatisticsSingleThread(); void InitCalculators(const std::vector &); }; -#ifdef ONLINE_DBG_MODE } // namespace mindspore -#endif #endif // MINDSPORE_TENSOR_SUMMARY_H diff --git a/mindspore/ccsrc/debug/tensor_data.h b/mindspore/ccsrc/debug/tensor_data.h index bd5830ee376..1f4b434b709 100644 --- a/mindspore/ccsrc/debug/tensor_data.h +++ b/mindspore/ccsrc/debug/tensor_data.h @@ -25,10 +25,7 @@ #include "ir/tensor.h" #endif -#ifdef ONLINE_DBG_MODE namespace mindspore { -#endif - namespace MsTypeId { typedef enum MsTypeId : unsigned int { kTypeUnknown = 0, @@ -444,7 +441,5 @@ class TensorData { mindspore::tensor::TensorPtr tensor_ptr_{nullptr}; #endif }; -#ifdef ONLINE_DBG_MODE } // namespace mindspore -#endif #endif // MINDSPORE_CCSRC_DEBUG_TENSOR_DATA_H_ diff --git a/mindspore/ccsrc/debug/tensor_load.h b/mindspore/ccsrc/debug/tensor_load.h index 6d145ee6565..49afcb83fc3 100644 --- a/mindspore/ccsrc/debug/tensor_load.h +++ b/mindspore/ccsrc/debug/tensor_load.h @@ -28,8 +28,8 @@ #include "debug/tensor_data.h" #ifdef ONLINE_DBG_MODE #include "debug/data_dump/dump_json_parser.h" -namespace mindspore { #endif +namespace mindspore { class TensorLoader { public: #ifndef __APPLE__ @@ -287,7 +287,5 @@ class TensorLoader { std::deque cache_evict_queue_; std::condition_variable evict_cond; }; -#ifdef ONLINE_DBG_MODE } // namespace mindspore -#endif #endif // MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_ diff --git a/mindspore/ccsrc/debug/utils.cc b/mindspore/ccsrc/debug/utils.cc new file mode 100644 index 00000000000..4ca9e779599 --- /dev/null +++ b/mindspore/ccsrc/debug/utils.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "debug/utils.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore { +bool CheckStoull(uint64_t *const output_digit, const std::string &input_str) { + try { + *output_digit = std::stoull(input_str); + } catch (const std::out_of_range &oor) { + MS_LOG(ERROR) << "Out of Range error: " << oor.what() << " when parse " << input_str; + return false; + } catch (const std::invalid_argument &ia) { + MS_LOG(ERROR) << "Invalid argument: " << ia.what() << " when parse " << input_str; + return false; + } + return true; +} + +bool CheckStoul(size_t *const output_digit, const std::string &input_str) { + try { + *output_digit = std::stoul(input_str); + } catch (const std::out_of_range &oor) { + MS_LOG(ERROR) << "Out of Range error: " << oor.what() << " when parse " << input_str; + return false; + } catch (const std::invalid_argument &ia) { + MS_LOG(ERROR) << "Invalid argument: " << ia.what() << " when parse " << input_str; + return false; + } + return true; +} + +bool CheckStoi(int64_t *const output_digit, const std::string &input_str) { + try { + *output_digit = std::stoi(input_str); + } catch (const std::out_of_range &oor) { + MS_LOG(ERROR) << "Out of Range error: " << oor.what() << " when parse " << input_str; + return false; + } catch (const std::invalid_argument &ia) { + MS_LOG(ERROR) << "Invalid argument: " << ia.what() << " when parse " << input_str; + return false; + } + return true; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/debug/utils.h b/mindspore/ccsrc/debug/utils.h new file mode 100644 index 00000000000..b532dff1c7a --- /dev/null +++ b/mindspore/ccsrc/debug/utils.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_UTILS_H +#define MINDSPORE_UTILS_H + +#include + +namespace mindspore { +bool CheckStoull(uint64_t *const output_digit, const std::string &input_str); + +bool CheckStoul(size_t *const output_digit, const std::string &input_str); + +bool CheckStoi(int64_t *const output_digit, const std::string &input_str); +} // namespace mindspore + +#endif // MINDSPORE_UTILS_H diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index b738c81c4ca..69939e4fa1f 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -112,6 +112,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} # dont remove the 4 lines above "../../../mindspore/ccsrc/debug/data_dump/dump_json_parser.cc" "../../../mindspore/ccsrc/debug/common.cc" + "../../../mindspore/ccsrc/debug/utils.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/hccl_adapter/all_to_all_v_calc_param.cc" "../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc" "../../../mindspore/ccsrc/runtime/device/memory_manager.cc"