From 18d1e8795aaaa40e0416c5b25cbfe09a7cb9ddd9 Mon Sep 17 00:00:00 2001 From: Harshvardhan Gupta Date: Fri, 11 Sep 2020 12:10:21 -0400 Subject: [PATCH] extend wp support to all types --- mindspore/ccsrc/debug/debug_services.cc | 88 ++++++++++++++++++++----- mindspore/ccsrc/debug/debug_services.h | 21 +++--- 2 files changed, 81 insertions(+), 28 deletions(-) diff --git a/mindspore/ccsrc/debug/debug_services.cc b/mindspore/ccsrc/debug/debug_services.cc index 449055a0b5..04b9925c92 100644 --- a/mindspore/ccsrc/debug/debug_services.cc +++ b/mindspore/ccsrc/debug/debug_services.cc @@ -58,13 +58,14 @@ void DebugServices::RemoveWatchpoint(unsigned int id) { watchpoint_table.erase(id); } -DebugServices::tensor_stats DebugServices::SummarizeTensor(const float *start, unsigned int n, bool need_min_max, +template +DebugServices::tensor_stats DebugServices::SummarizeTensor(const T *start, unsigned int n, bool need_min_max, bool need_mean_sd) { tensor_stats stats; for (unsigned int i = 0; i < n; ++i) { - float val = start[i]; - stats.has_nan = stats.has_nan || isnan(val); - stats.has_inf = stats.has_inf || isinf(val); + auto val = static_cast(start[i]); + stats.has_nan = stats.has_nan || std::isnan(val); + stats.has_inf = stats.has_inf || std::isinf(val); if (stats.has_inf && stats.has_nan) { // other statistics don't make sense in this case break; @@ -76,9 +77,7 @@ DebugServices::tensor_stats DebugServices::SummarizeTensor(const float *start, u } if (need_mean_sd) { - // for mean and sd calculation see - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm - float delta = val - stats.mean; + double delta = val - stats.mean; stats.mean += delta / (i + 1); stats.m2 += delta * (val - stats.mean); } @@ -109,13 +108,7 @@ void DebugServices::CheckWatchpoints(std::vector *name, std::vector bool inf_nan_enabled = false; for (auto w_table_item : watchpoint_table) { auto wp = std::get<1>(w_table_item); - - // if (!wp.conditions.condition_list[IS_OVERFLOW].enabled) { - if (wp.condition.type != IS_OVERFLOW) { - // only overflow condition supports all data types - if (tensor_dtype != kNumberTypeFloat && tensor_dtype != kNumberTypeFloat32) continue; - } - + if (wp.condition.type != IS_OVERFLOW && tensor_dtype == kNumberTypeBool) continue; if (wp.IsNodeIncluded(tensor_name_no_slot)) { min_max_enabled |= wp.min_max_enabled(); mean_sd_enabled |= wp.mean_sd_enabled(); @@ -124,11 +117,70 @@ void DebugServices::CheckWatchpoints(std::vector *name, std::vector } } tensor_stats stats; - + uint num_elements = tensor_ptr->DataSize(); if (min_max_enabled || mean_sd_enabled || inf_nan_enabled) { - auto *start_addr = reinterpret_cast(tensor_ptr->data_c()); - unsigned int num_elements = (tensor_ptr->data().nbytes()) / sizeof(float); - stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + switch (tensor_dtype) { + case kNumberTypeUInt8: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeInt8: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeUInt16: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeInt16: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeUInt32: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeInt32: + case kNumberTypeInt: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeUInt64: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeInt64: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeFloat16: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeFloat32: + case kNumberTypeFloat: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + case kNumberTypeFloat64: { + auto start_addr = reinterpret_cast(tensor_ptr->data_c()); + stats = SummarizeTensor(start_addr, num_elements, min_max_enabled, mean_sd_enabled); + break; + } + default: + MS_LOG(INFO) << "Unsupported tensor type"; + break; + } } for (auto &it : watchpoints_to_check_table) { diff --git a/mindspore/ccsrc/debug/debug_services.h b/mindspore/ccsrc/debug/debug_services.h index 4ae1d51dd0..1dab625632 100644 --- a/mindspore/ccsrc/debug/debug_services.h +++ b/mindspore/ccsrc/debug/debug_services.h @@ -93,26 +93,26 @@ class DebugServices { } watchpoint_t; struct tensor_stats { - float min = std::numeric_limits::max(); - float max = std::numeric_limits::lowest(); + double min = std::numeric_limits::max(); + double max = std::numeric_limits::lowest(); bool has_inf = false; bool has_nan = false; unsigned int n = 0; - float mean = 0.0; - float m2 = 0.0; + double mean = 0.0; + double m2 = 0.0; - float statLookup(CONDITION_TYPE type) const { + double statLookup(CONDITION_TYPE type) const { if (type == MAX_GT || type == MAX_LT) return max; if (type == MIN_GT || type == MIN_LT) return min; if (type == MAX_MIN_GT || type == MAX_MIN_LT) return (max - min); if (type == MEAN_GT || type == MEAN_LT) return mean; if (type == SD_GT || type == SD_LT) return getStandardDeviation(); - return std::numeric_limits::quiet_NaN(); + return std::numeric_limits::quiet_NaN(); } - float getMean() const { return mean; } + double getMean() const { return mean; } - float getVariance() const { + double getVariance() const { if (n > 1) { return m2 / (n - 1); } else { @@ -120,7 +120,7 @@ class DebugServices { } } - float getStandardDeviation() const { return sqrt(getVariance()); } + double getStandardDeviation() const { return sqrt(getVariance()); } }; void AddWatchpoint(unsigned int id, unsigned int watch_condition, float parameter, @@ -152,7 +152,8 @@ class DebugServices { TensorLoader *tensor_loader_; - static tensor_stats SummarizeTensor(const float *start, unsigned int n, bool need_min_max, bool need_mean_sd); + template + static tensor_stats SummarizeTensor(const T *start, unsigned int n, bool need_min_max, bool need_mean_sd); }; } // namespace mindspore