forked from mindspore-Ecosystem/mindspore
!15981 Address ci alarm in master
From: @tina_mengting_zhang
Reviewed-by: @john_tzanakakis, @robingrosman
Signed-off-by: @john_tzanakakis
commit 32c8a1529e
@@ -29,11 +29,7 @@
 #ifdef ONLINE_DBG_MODE
 namespace mindspore {
 #endif
-DebugServices::DebugServices() {
-  tensor_loader_ = new TensorLoader();
-  uint32_t iter_num = -1;
-  tensor_loader_->set_iter_num(iter_num);
-}
+DebugServices::DebugServices() { tensor_loader_ = std::make_shared<TensorLoader>(); }

 DebugServices::DebugServices(const DebugServices &other) {
   tensor_loader_ = other.tensor_loader_;

@@ -48,8 +44,6 @@ DebugServices &DebugServices::operator=(const DebugServices &other) {
   return *this;
 }

-DebugServices::~DebugServices() { delete tensor_loader_; }
-
 void DebugServices::AddWatchpoint(
   unsigned int id, unsigned int watch_condition, float parameter,
   const std::vector<std::tuple<std::string, bool>> &check_node_list, const std::vector<parameter_t> &parameter_list,
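Most of this hunk is the ownership cleanup: the raw `new TensorLoader()` / `delete tensor_loader_` pair becomes a `std::shared_ptr`, so the destructor and the copy operations can be left to the compiler. A minimal standalone sketch of that pattern (the `Loader` and `Owner` names are illustrative, not from the patch):

```cpp
#include <memory>

struct Loader {
  void set_iter_num(unsigned int n) { iter_num_ = n; }
  unsigned int iter_num_ = 0;
};

class Owner {
 public:
  // Ownership lives in the smart pointer; no delete is needed anywhere.
  Owner() : loader_(std::make_shared<Loader>()) {}
  // Copying now shares the Loader instead of risking a double delete.
  Owner(const Owner &other) = default;
  Owner &operator=(const Owner &other) = default;
  ~Owner() = default;

 private:
  std::shared_ptr<Loader> loader_;
};

int main() {
  Owner a;
  Owner b = a;  // both objects share the same Loader
  return 0;
}
```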
@@ -77,8 +71,9 @@ void DebugServices::RemoveWatchpoint(unsigned int id) {
   watchpoint_table.erase(id);
 }

-std::unique_ptr<ITensorSummary> GetSummaryPtr(const std::shared_ptr<TensorData> &tensor, void *previous_tensor_ptr,
-                                              uint32_t num_elements, int tensor_dtype) {
+std::unique_ptr<ITensorSummary> GetSummaryPtr(const std::shared_ptr<TensorData> &tensor,
+                                              void *const previous_tensor_ptr, uint32_t num_elements,
+                                              int tensor_dtype) {
   switch (tensor_dtype) {
     case DbgDataType::DT_UINT8: {
       return std::make_unique<TensorSummary<uint8_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
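Several signatures in this commit add a top-level `const` to pointer parameters (`void *const previous_tensor_ptr`, `std::string *const`, `GraphProto *const`). That `const` only promises the function never reseats the pointer itself; the data it points to stays writable. A small sketch of the distinction, with hypothetical names:

```cpp
#include <cstdint>

// `out` cannot be pointed elsewhere inside the function,
// but the buffer it points to can still be modified.
void FillBuffer(uint8_t *const out, uint32_t num_elements) {
  for (uint32_t i = 0; i < num_elements; ++i) {
    out[i] = static_cast<uint8_t>(i);
  }
  // out = nullptr;  // would not compile: `out` itself is const
}

int main() {
  uint8_t buf[4] = {0};
  FillBuffer(buf, 4);
  return 0;
}
```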
@@ -156,8 +151,8 @@ void *DebugServices::GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bo

 void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck,
                                           const std::string &tensor_name, const std::string &tensor_name_no_slot,
-                                          bool *previous_iter_tensor_needed, std::string *qualified_tensor_name,
-                                          std::vector<watchpoint_t> *watchpoints_to_check) {
+                                          bool *previous_iter_tensor_needed, std::string *const qualified_tensor_name,
+                                          std::vector<watchpoint_t> *const watchpoints_to_check) {
   for (auto w_table_item : watchpoint_table) {
     auto wp = std::get<1>(w_table_item);
     // check ONLY init conditions on initial suspended state.

@@ -191,10 +186,11 @@ void DebugServices::AddAnalyzedTensorToCache(const bool recheck, const unsigned
   }
 }

-void DebugServices::CheckWatchpoints(std::vector<std::string> *name, std::vector<std::string> *slot,
-                                     std::vector<int> *condition, std::vector<unsigned int> *watchpoint_id,
-                                     std::vector<std::vector<parameter_t>> *parameters,
-                                     std::vector<int32_t> *error_codes, const std::vector<std::string> &op_overflows,
+void DebugServices::CheckWatchpoints(std::vector<std::string> *const name, std::vector<std::string> *const slot,
+                                     std::vector<int> *const condition, std::vector<unsigned int> *const watchpoint_id,
+                                     std::vector<std::vector<parameter_t>> *const parameters,
+                                     std::vector<int32_t> *const error_codes,
+                                     const std::vector<std::string> &op_overflows,
                                      const std::vector<std::string> &async_file_pool,
                                      std::vector<std::shared_ptr<TensorData>> *tensor_list, const bool init_dbg_suspend,
                                      const bool step_end, const bool recheck, std::vector<unsigned int> *device_id,

@@ -818,9 +814,10 @@ std::vector<std::shared_ptr<TensorData>> DebugServices::ReadNeededDumpedTensors(
 }
 #endif

-void DebugServices::ReadNodesTensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
-                                     std::vector<char *> *data_ptr, std::vector<ssize_t> *data_size,
-                                     std::vector<unsigned int> *dtype, std::vector<std::vector<int64_t>> *shape) {
+void DebugServices::ReadNodesTensors(const std::vector<std::string> &name, std::vector<std::string> *const ret_name,
+                                     std::vector<char *> *const data_ptr, std::vector<ssize_t> *const data_size,
+                                     std::vector<unsigned int> *const dtype,
+                                     std::vector<std::vector<int64_t>> *const shape) {
   std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
   tensor_loader_->SearchTensors(name, &result_list);


@@ -929,10 +926,10 @@ std::vector<std::shared_ptr<TensorData>> DebugServices::GetNodeTensor(const CNod
 }
 #endif

-bool DebugServices::TensorExistsInCurrent(std::string tensor_name) {
+bool DebugServices::TensorExistsInCurrent(const std::string &tensor_name) {
   return tensor_loader_->TensorExistsInCurrent(tensor_name);
 }
-void DebugServices::MoveTensorCurrentToPrev(std::string tensor_name) {
+void DebugServices::MoveTensorCurrentToPrev(const std::string &tensor_name) {
   tensor_loader_->MoveTensorCurrentToPrev(tensor_name);
 }

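`ReadNodesTensors`, `TensorExistsInCurrent`, and `MoveTensorCurrentToPrev` now take their string and vector arguments by `const` reference instead of by value, so no copy of the argument is made on each call. A standalone sketch of the idiom (the `TensorCache` class is an illustrative stand-in, not the MindSpore type):

```cpp
#include <string>
#include <unordered_set>

class TensorCache {
 public:
  // Pass by const reference: no copy of the (possibly long) tensor name per lookup.
  bool Exists(const std::string &tensor_name) const { return names_.count(tensor_name) > 0; }

  // With a by-value `std::string tensor_name` parameter, every call would copy
  // the argument even though the function only reads it.
  void Insert(const std::string &tensor_name) { names_.insert(tensor_name); }

 private:
  std::unordered_set<std::string> names_;
};

int main() {
  TensorCache cache;
  cache.Insert("Default/Conv2D-op1:0");
  return cache.Exists("Default/Conv2D-op1:0") ? 0 : 1;
}
```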
@@ -52,7 +52,7 @@ class DebugServices {

   DebugServices &operator=(const DebugServices &other);

-  ~DebugServices();
+  ~DebugServices() = default;

   enum CONDITION_TYPE {
     HAS_NAN,

@@ -121,7 +121,7 @@ class DebugServices {
     std::vector<parameter_t> parameter_list;
     size_t location = 0;

-    std::string FindQualifiedTensorName(const std::string &tensor_name) {
+    std::string FindQualifiedTensorName(const std::string &tensor_name) const {
       std::string node_name = tensor_name.substr(0, tensor_name.find_first_of(':'));
       for (auto check_node : check_node_list) {
         std::string w_name = std::get<0>(check_node);

@@ -135,17 +135,17 @@ class DebugServices {
       return {};
     }

-    bool is_gt_wp() {
+    bool is_gt_wp() const {
       return condition.type == MAX_GT || condition.type == MIN_GT || condition.type == MEAN_GT ||
              condition.type == SD_GT || condition.type == MAX_MIN_GT;
     }

-    bool is_lt_wp() {
+    bool is_lt_wp() const {
       return condition.type == MAX_LT || condition.type == MIN_LT || condition.type == MEAN_LT ||
              condition.type == SD_LT || condition.type == MAX_MIN_LT;
     }

-    bool min_max_enabled() {
+    bool min_max_enabled() const {
       return condition.type == MAX_LT || condition.type == MAX_GT || condition.type == MIN_LT ||
              condition.type == MIN_GT || condition.type == MAX_MIN_LT || condition.type == MAX_MIN_GT ||
              (condition.type == INIT && (!parameter_list[1].disabled || !parameter_list[2].disabled)) ||
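The predicates on the watchpoint structure (`is_gt_wp`, `is_lt_wp`, `min_max_enabled`, and the rest below) gain a trailing `const`, the usual fix for a "member function could be declared const" lint: it documents that the call does not mutate the object and makes the call legal on const instances. A small illustrative sketch (the `Watchpoint` type here is hypothetical, not the struct from this header):

```cpp
#include <iostream>

enum ConditionType { MAX_GT, MIN_GT, MAX_LT, MIN_LT };

struct Watchpoint {
  ConditionType type;
  // Trailing const: callable on a const Watchpoint, cannot modify members.
  bool is_gt_wp() const { return type == MAX_GT || type == MIN_GT; }
  bool is_lt_wp() const { return type == MAX_LT || type == MIN_LT; }
};

int main() {
  const Watchpoint wp{MAX_GT};
  // Works only because is_gt_wp() is a const member function.
  std::cout << std::boolalpha << wp.is_gt_wp() << '\n';
  return 0;
}
```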
@@ -153,7 +153,7 @@ class DebugServices {
              (condition.type == TOO_SMALL && (!parameter_list[1].disabled || !parameter_list[2].disabled));
     }
     // inf or nan related condition set
-    bool inf_nan_enabled() {
+    bool inf_nan_enabled() const {
       return condition.type == HAS_INF || condition.type == HAS_NAN || condition.type == GENERAL_OVERFLOW;
     }
     // mean or sd related condition set

@@ -166,7 +166,7 @@ class DebugServices {
       return (condition.type == TOO_LARGE && !parameter_list[0].disabled) ||
              (condition.type == TOO_SMALL && !parameter_list[0].disabled);
     }
-    bool zero_percentage_enabled() { return condition.type == ALL_ZERO || condition.type == INIT; }
+    bool zero_percentage_enabled() const { return condition.type == ALL_ZERO || condition.type == INIT; }

     bool tensor_update_ratio_mean_enabled() const {
       return condition.type == CHANGE_TOO_LARGE || condition.type == CHANGE_TOO_SMALL;

@@ -191,9 +191,9 @@ class DebugServices {
   void RemoveWatchpoint(unsigned int id);

   void CheckWatchpoints(std::vector<std::string> *name, std::vector<std::string> *slot, std::vector<int> *condition,
-                        std::vector<unsigned int> *watchpoint_id, std::vector<std::vector<parameter_t>> *parameters,
-                        std::vector<int32_t> *error_code, const std::vector<std::string> &op_overflows,
-                        const std::vector<std::string> &async_file_pool,
+                        std::vector<unsigned int> *const watchpoint_id,
+                        std::vector<std::vector<parameter_t>> *parameters, std::vector<int32_t> *error_code,
+                        const std::vector<std::string> &op_overflows, const std::vector<std::string> &async_file_pool,
                         std::vector<std::shared_ptr<TensorData>> *tensor_list, bool init_dbg_suspend,
                         const bool step_end, const bool recheck, std::vector<unsigned int> *device_id = nullptr,
                         std::vector<unsigned int> *root_graph_id = nullptr);

@@ -243,9 +243,9 @@ class DebugServices {
                             const std::vector<std::string> &async_file_pool,
                             std::vector<std::shared_ptr<TensorData>> *tensor_list);
 #endif
-  void ReadNodesTensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
+  void ReadNodesTensors(const std::vector<std::string> &name, std::vector<std::string> *ret_name,
                         std::vector<char *> *data_ptr, std::vector<ssize_t> *data_size,
-                        std::vector<unsigned int> *dtype, std::vector<std::vector<int64_t>> *shape);
+                        std::vector<unsigned int> *dtype, std::vector<std::vector<int64_t>> *const shape);
 #ifdef ONLINE_DBG_MODE
   bool IsWatchPoint(const std::string &kernel_name, const CNodePtr &kernel = nullptr) const;


@@ -282,9 +282,9 @@ class DebugServices {
   std::vector<std::shared_ptr<TensorData>> GetNodeTensor(const CNodePtr &kernel);
 #endif

-  bool TensorExistsInCurrent(std::string tensor_name);
+  bool TensorExistsInCurrent(const std::string &tensor_name);

-  void MoveTensorCurrentToPrev(std::string tensor_name);
+  void MoveTensorCurrentToPrev(const std::string &tensor_name);

   void SetNetName(std::string net_name);


@@ -308,7 +308,7 @@ class DebugServices {
   std::string dump_dir;
   bool is_sync_mode;

-  TensorLoader *tensor_loader_;
+  std::shared_ptr<TensorLoader> tensor_loader_;
 };
 #ifdef ONLINE_DBG_MODE
 } // namespace mindspore
@@ -191,7 +191,7 @@ void Debugger::SetOpOverflowBinPath(uint32_t graph_id) {
   d = opendir(overflow_bin_path.c_str());
   if (d != nullptr) {
     struct dirent *dir;
-    while ((dir = readdir(d)) != NULL) {
+    while ((dir = readdir(d)) != nullptr) {
       if (dir->d_type == DT_REG) {
         std::string file_path = overflow_bin_path;
         file_path.append(dir->d_name);
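The `readdir` loop now compares against `nullptr` instead of the macro `NULL`, the type-safe spelling in modern C++. A self-contained sketch of the same scan loop; the "." directory is just an example path, not one used by the debugger:

```cpp
#include <dirent.h>

#include <iostream>

int main() {
  DIR *d = opendir(".");  // example path, not from the patch
  if (d != nullptr) {
    struct dirent *dir = nullptr;
    // readdir returns nullptr once the directory stream is exhausted.
    while ((dir = readdir(d)) != nullptr) {
      if (dir->d_type == DT_REG) {
        std::cout << dir->d_name << '\n';  // print regular files only
      }
    }
    (void)closedir(d);
  }
  return 0;
}
```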
@@ -225,7 +225,7 @@ void Debugger::CheckDatasetSinkMode() {
   }
 }

-bool Debugger::CheckDebuggerDumpEnabled() {
+bool Debugger::CheckDebuggerDumpEnabled() const {
   // see if dump is enabled
   if (device_target_ == kGPUDevice) {
     return device::KernelRuntime::DumpDataEnabled();

@@ -233,7 +233,7 @@ bool Debugger::CheckDebuggerDumpEnabled() {
   return false;
 }

-bool Debugger::CheckDebuggerEnabled() {
+bool Debugger::CheckDebuggerEnabled() const {
   // get env variables to configure debugger
   const char *env_enable_char = std::getenv("ENABLE_MS_DEBUGGER");
   if (env_enable_char != nullptr) {

@@ -259,7 +259,7 @@ void Debugger::CheckDebuggerEnabledParam() {
   }
 }

-bool Debugger::CheckDebuggerPartialMemoryEnabled() {
+bool Debugger::CheckDebuggerPartialMemoryEnabled() const {
   const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
   if (env_partial_mem_str != nullptr) {
     MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;

@@ -270,7 +270,7 @@ bool Debugger::CheckDebuggerPartialMemoryEnabled() {
   return false;
 }

-bool Debugger::DebuggerBackendEnabled() { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }
+bool Debugger::DebuggerBackendEnabled() const { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }

 void Debugger::Reset() {
   // access lock for public method

@@ -326,7 +326,7 @@ void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
     LoadParametersAndConst();
     // revert graph ptr to original value
     graph_ptr_ = dbg_graph_ptr;
-    SendMultiGraphsAndSuspend(graph_proto_list_, graph_sum);
+    SendMultiGraphsAndSuspend(graph_proto_list_);
     graph_proto_list_.clear();
   }
 }

@@ -367,7 +367,7 @@ void Debugger::PostExecute(const KernelGraphPtr &graph_ptr) {
   }
 }

-bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) {
+bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) const {
   if (debugger_enabled_ && !is_dataset_graph_) {
     auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
     // if node has a watchpoint on it, is next_to node, or continue_to node then read the kernel tensor data

@@ -484,7 +484,7 @@ void Debugger::CheckDatasetGraph() {

 GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
   // convert kernel graph to debugger modelproto
-  ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
+  ModelProto model = GetDebuggerFuncGraphProto(graph_ptr);
   return model.graph();
 }

@@ -542,7 +542,7 @@ bool Debugger::SendMetadata(bool version_check) {
   return ret;
 }

-void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list, uint32_t graph_sum) {
+void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list) {
   if (!SendMetadata(true)) {
     return;
   }

@@ -733,7 +733,7 @@ void Debugger::ProcessKViewCMD(const EventReply &reply) {
   }
 }

-void AddTensorProtoInfo(TensorProto *tensor_item, TensorProto tensor) {
+void AddTensorProtoInfo(TensorProto *tensor_item, const TensorProto &tensor) {
   tensor_item->set_node_name(tensor.node_name());
   tensor_item->set_slot(tensor.slot());
   tensor_item->set_iter(tensor.iter());

@@ -1008,9 +1008,9 @@ std::string GetTensorFullName(const TensorProto &tensor) {

 bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }

-bool Debugger::partial_memory() { return partial_memory_; }
+bool Debugger::partial_memory() const { return partial_memory_; }

-void Debugger::SetCurNode(std::string cur_name) {
+void Debugger::SetCurNode(const std::string &cur_name) {
   // access lock for public method
   std::lock_guard<std::mutex> a_lock(access_lock_);
   cur_name_ = cur_name;

@@ -1046,7 +1046,7 @@ std::vector<std::string> Debugger::CheckOpOverflow() {
     MS_LOG(INFO) << "processing bin file path " << overflow_bin_path << ", graph id " << graph_id;
     if (d != nullptr) {
       struct dirent *dir = nullptr;
-      while ((dir = readdir(d)) != NULL) {
+      while ((dir = readdir(d)) != nullptr) {
        if (dir->d_type == DT_REG) {
          std::string file_path = overflow_bin_path;
          file_path.append(dir->d_name);

@@ -1126,7 +1126,7 @@ bool Debugger::CheckPort(const char *port) {
   return true;
 }

-bool Debugger::CheckIp(const char *host) {
+bool Debugger::CheckIp(const char *host) const {
   std::regex reg_ip(
     "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
     "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"

@@ -1137,7 +1137,7 @@ bool Debugger::CheckIp(const char *host) {
   return std::regex_match(host_str, smat, reg_ip);
 }

-uint32_t Debugger::GetFirstRunGraphId() { return rungraph_id_list_.front(); }
+uint32_t Debugger::GetFirstRunGraphId() const { return rungraph_id_list_.front(); }

 void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
   MS_EXCEPTION_IF_NULL(anf_node);
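`CheckIp` validates the host string with `std::regex_match` against per-octet alternations, two of which are visible above. A simplified standalone sketch of the same approach; the pattern below is an assumption for illustration and is not the exact expression in the file:

```cpp
#include <iostream>
#include <regex>
#include <string>

bool CheckIpSketch(const std::string &host) {
  // One octet: 250-255, 200-249, 100-199, 10-99, or 0-9.
  static const std::regex reg_ip(
    "(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
    "([.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])){3}");
  std::smatch smat;
  // regex_match requires the whole string to match, not just a prefix.
  return std::regex_match(host, smat, reg_ip);
}

int main() {
  std::cout << std::boolalpha << CheckIpSketch("127.0.0.1") << ' '
            << CheckIpSketch("999.1.1.1") << '\n';  // true false
  return 0;
}
```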
@@ -1201,7 +1201,6 @@ void Debugger::LoadGraphOutputs() {
   int exec_order = 1;
   for (const auto &node : apply_kernels) {
     MS_EXCEPTION_IF_NULL(node);
-    auto node_name = AnfAlgo::GetCNodeName(node);
     std::string kernel_name = node->fullname_with_scope();
     auto output_size = AnfAlgo::GetOutputTensorNum(node);
     if (partial_memory_) {

@@ -1247,7 +1246,7 @@ void Debugger::ClearCurrentData() {
   if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()))
     debug_services_->EmptyCurrentTensor();
 }
-bool Debugger::TensorExistsInCurrent(std::string tensor_name) {
+bool Debugger::TensorExistsInCurrent(const std::string &tensor_name) {
   return debug_services_->TensorExistsInCurrent(tensor_name);
 }

@@ -82,7 +82,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   // don't need a graph_ptr because it is saved during pre_execute
   void PostExecute(const KernelGraphPtr &graph_ptr = nullptr);

-  bool ReadNodeDataRequired(const CNodePtr &kernel);
+  bool ReadNodeDataRequired(const CNodePtr &kernel) const;

   void PostExecuteNode(const CNodePtr &kernel, bool last_kernel);


@@ -107,9 +107,9 @@ class Debugger : public std::enable_shared_from_this<Debugger> {

   bool debugger_enabled() const;

-  bool partial_memory();
+  bool partial_memory() const;

-  void SetCurNode(std::string cur_name);
+  void SetCurNode(const std::string &cur_name);

   std::string run_level() const;


@@ -120,7 +120,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   void SetStreamTaskToOpnameMap(const std::map<std::pair<uint32_t, uint32_t>, std::string> &mapping);

   // check if any feature that uses the debugger backend is enabled
-  bool DebuggerBackendEnabled();
+  bool DebuggerBackendEnabled() const;

   void SetTrainingDone(bool training_done);


@@ -140,13 +140,13 @@ class Debugger : public std::enable_shared_from_this<Debugger> {

   void LoadGraphs(const KernelGraphPtr &graph_ptr);

-  uint32_t GetFirstRunGraphId();
+  uint32_t GetFirstRunGraphId() const;

   void SetGraphPtr(const KernelGraphPtr &graph_ptr) { graph_ptr_ = graph_ptr; }

   std::list<KernelGraphPtr> GetGraphPtrList() { return graph_ptr_list_; }

-  bool TensorExistsInCurrent(std::string tensor_name);
+  bool TensorExistsInCurrent(const std::string &tensor_name);

  private:
   // private constructor for singleton

@@ -160,14 +160,14 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   void SetOpOverflowBinPath(uint32_t graph_id);

   // check if dump using debugger backend is enabled
-  bool CheckDebuggerDumpEnabled();
+  bool CheckDebuggerDumpEnabled() const;

   // check if debugger enabled
-  bool CheckDebuggerEnabled();
+  bool CheckDebuggerEnabled() const;

   void CheckDebuggerEnabledParam();

-  bool CheckDebuggerPartialMemoryEnabled();
+  bool CheckDebuggerPartialMemoryEnabled() const;

   // check and save graph pointer
   void CheckGraphPtr(const KernelGraphPtr &graph_ptr);

@@ -181,7 +181,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   // send graph and enter command wait loop
   void SendGraphAndSuspend(const GraphProto &graph_proto);

-  void SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list, uint32_t graph_sum);
+  void SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list);

   // wait for command and process command
   // send command request and process reply in a loop

@@ -223,7 +223,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   bool CheckPort(const char *port);

   // Check if the IP is valid
-  bool CheckIp(const char *host);
+  bool CheckIp(const char *host) const;

   void LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index);

@@ -336,7 +336,7 @@ debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr
   return model_;
 }

-void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto,
+void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
                                             LocDebugDumpMode dump_location) {
   if (func_graph == nullptr || graph_proto == nullptr) {
     return;

@@ -378,7 +378,7 @@ void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, deb
   }
 }

-void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto,
+void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
                                          std::map<AnfNodePtr, size_t> *const_map_ptr, LocDebugDumpMode dump_location) {
   if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
     return;

@@ -402,8 +402,8 @@ void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugge

 void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                         std::map<AnfNodePtr, size_t> *apply_map_ptr,
-                                        std::map<AnfNodePtr, size_t> *const_map_ptr, debugger::GraphProto *graph_proto,
-                                        LocDebugDumpMode dump_location) {
+                                        std::map<AnfNodePtr, size_t> *const_map_ptr,
+                                        debugger::GraphProto *const graph_proto, LocDebugDumpMode dump_location) {
   if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
       graph_proto == nullptr) {
     return;
@@ -48,14 +48,13 @@ class DebuggerProtoExporter {
   void SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto);
   void SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto);
   void SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto);
-  void ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto,
+  void ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
                        LocDebugDumpMode dump_location = kDebugOff);
   void ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto);
-  void ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto,
+  void ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
                     std::map<AnfNodePtr, size_t> *const_map_ptr, LocDebugDumpMode dump_location = kDebugOff);
   void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
-                   std::map<AnfNodePtr, size_t> *const_map_ptr, debugger::GraphProto *graph_proto,
+                   std::map<AnfNodePtr, size_t> *const_map_ptr, debugger::GraphProto *const graph_proto,
                    LocDebugDumpMode dump_location = kDebugOff);
   void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
                              const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
@@ -58,7 +58,7 @@ void AllCloseCalculator::ProcessElement(double current, double previous) {
   result = result && (std::abs(current - previous) <= (atol + rtol * std::abs(previous)));
 }

-bool AllCloseCalculator::IsAllClose() { return result; }
+bool AllCloseCalculator::IsAllClose() const { return result; }

 MeanCalculator::MeanCalculator() : mean(0.0), count(0) {}

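`AllCloseCalculator::ProcessElement` applies the usual element-wise allclose test: two values are considered close when |current - previous| <= atol + rtol * |previous|. A minimal sketch of that accumulation with illustrative tolerances (the real class's defaults are not shown in this diff):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double atol = 1e-8;  // example absolute tolerance
  const double rtol = 1e-5;  // example relative tolerance
  const double current[] = {1.0, 2.000001, 3.0};
  const double previous[] = {1.0, 2.0, 3.5};

  bool result = true;
  for (int i = 0; i < 3; ++i) {
    // Same predicate as the ProcessElement line above.
    result = result && (std::abs(current[i] - previous[i]) <= (atol + rtol * std::abs(previous[i])));
  }
  std::printf("all close: %s\n", result ? "true" : "false");
  return 0;
}
```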
@@ -68,7 +68,7 @@ void MeanCalculator::ProcessElement(double value) {
   mean += delta / count;
 }

-double MeanCalculator::GetMean() { return mean; }
+double MeanCalculator::GetMean() const { return mean; }

 VarianceAndMeanCalculator::VarianceAndMeanCalculator() : mean(0.0), count(0), m2(0.0) {}


@@ -79,9 +79,9 @@ void VarianceAndMeanCalculator::ProcessElement(double value) {
   m2 += delta * (value - mean);
 }

-double VarianceAndMeanCalculator::GetMean() { return mean; }
+double VarianceAndMeanCalculator::GetMean() const { return mean; }

-double VarianceAndMeanCalculator::GetVariance() {
+double VarianceAndMeanCalculator::GetVariance() const {
   if (count > 1) {
     return m2 / (count - 1);
   } else {
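The update rules visible in these hunks (`mean += delta / count` in `MeanCalculator` and `m2 += delta * (value - mean)` in `VarianceAndMeanCalculator`) are Welford's single-pass mean/variance recurrence. A self-contained sketch of that recurrence, using the same sample-variance convention `m2 / (count - 1)` as above:

```cpp
#include <cstdio>

struct OnlineStats {
  int count = 0;
  double mean = 0.0;
  double m2 = 0.0;  // running sum of squared deviations from the mean

  void ProcessElement(double value) {
    ++count;
    const double delta = value - mean;
    mean += delta / count;          // update running mean
    m2 += delta * (value - mean);   // uses the *updated* mean, as in Welford's algorithm
  }

  double GetVariance() const { return count > 1 ? m2 / (count - 1) : 0.0; }
};

int main() {
  OnlineStats s;
  for (double v : {2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0}) s.ProcessElement(v);
  std::printf("mean=%.3f sample variance=%.3f\n", s.mean, s.GetVariance());
  return 0;
}
```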
@@ -92,7 +92,7 @@ double VarianceAndMeanCalculator::GetVariance() {
 double VarianceAndMeanCalculator::GetStandardDeviation() { return sqrt(GetVariance()); }

 template <typename T>
-TensorSummary<T>::TensorSummary(void *current_tensor_ptr, void *previous_tensor_ptr, uint32_t num_elements)
+TensorSummary<T>::TensorSummary(void *current_tensor_ptr, void *const previous_tensor_ptr, uint32_t num_elements)
     : current_tensor_ptr(reinterpret_cast<T *>(current_tensor_ptr)),
       prev_tensor_ptr(reinterpret_cast<T *>(previous_tensor_ptr)),
       num_elements(num_elements),
@@ -48,7 +48,7 @@ class AllCloseCalculator {
   AllCloseCalculator();
   ~AllCloseCalculator() = default;
   void ProcessElement(double current, double previous);
-  bool IsAllClose();
+  bool IsAllClose() const;
   void set_atol(double value) { atol = value; }
   void set_rtol(double value) { rtol = value; }


@@ -63,7 +63,7 @@ class MeanCalculator {
   MeanCalculator();
   ~MeanCalculator() = default;
   void ProcessElement(double value);
-  double GetMean();
+  double GetMean() const;

  protected:
   double mean;

@@ -76,8 +76,8 @@ class VarianceAndMeanCalculator {
   ~VarianceAndMeanCalculator() = default;
   void ProcessElement(double value);
   double GetStandardDeviation();
-  double GetVariance();
-  double GetMean();
+  double GetVariance() const;
+  double GetMean() const;

  private:
   double mean;
@@ -178,11 +178,11 @@ class TensorData {

   ~TensorData() {}

-  std::string GetName() { return this->name; }
+  std::string GetName() const { return this->name; }

-  size_t GetSlot() { return this->slot; }
+  size_t GetSlot() const { return this->slot; }

-  int GetExecutionOrder() { return this->execution_order; }
+  int GetExecutionOrder() const { return this->execution_order; }

   void SetExecutionOrder(int execution_order) { this->execution_order = execution_order; }

@@ -47,12 +47,12 @@ class TensorLoader {

   void SwapCurrentPrev() { tensor_list_map.swap(prev_tensor_list_map); }

-  bool TensorExistsInCurrent(std::string tensor_name) {
+  bool TensorExistsInCurrent(std::string tensor_name) const {
     return tensor_list_map.find(tensor_name) != tensor_list_map.end();
   }

   // only parameters will return true
-  bool PrevTensorExistsInCurrent(std::string tensor_name) { return TensorExistsInCurrent(tensor_name + ":prev"); }
+  bool PrevTensorExistsInCurrent(std::string tensor_name) const { return TensorExistsInCurrent(tensor_name + ":prev"); }

   void MoveParametersCurrentToPrev() {
     MS_LOG(INFO) << "Moving parameters from current map to previous map";

@@ -69,7 +69,7 @@ class TensorLoader {
     }
   }

-  bool IsPrevTensor(std::string tensor_name) {
+  bool IsPrevTensor(std::string tensor_name) const {
     const std::string suffix = ":prev";
     if (tensor_name.length() <= suffix.length()) return false;
     return std::equal(suffix.rbegin(), suffix.rend(), tensor_name.rbegin());
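`IsPrevTensor` detects the ":prev" suffix by comparing the reversed suffix against the tail of the name with `std::equal`, which avoids building a temporary substring. A standalone sketch of the same idiom:

```cpp
#include <algorithm>
#include <iostream>
#include <string>

// Returns true when `name` ends with `suffix` and is strictly longer than it,
// mirroring the `<=` length check in the member function above.
bool EndsWith(const std::string &name, const std::string &suffix) {
  if (name.length() <= suffix.length()) return false;
  return std::equal(suffix.rbegin(), suffix.rend(), name.rbegin());
}

int main() {
  std::cout << std::boolalpha << EndsWith("Default/fc1.weight:prev", ":prev") << ' '
            << EndsWith("Default/fc1.weight:0", ":prev") << '\n';  // true false
  return 0;
}
```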
@@ -100,13 +100,13 @@ class TensorLoader {
     return tensor_list;
   }

-  std::shared_ptr<TensorData> GetTensor(const std::string &tensor_name) {
+  std::shared_ptr<TensorData> GetTensor(const std::string &tensor_name) const {
     auto iter = tensor_list_map.find(tensor_name);
     if (iter != tensor_list_map.end()) return iter->second;
     return nullptr;
   }

-  uint32_t GetIterNum() { return iter_num; }
+  uint32_t GetIterNum() const { return iter_num; }

   std::map<std::string, std::shared_ptr<TensorData>> GetTensorMap() { return tensor_list_map; }
