diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc index 89589e10a37..582ff7f65b4 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc @@ -23,6 +23,7 @@ #include #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" +#include "utils/ms_utils.h" namespace mindspore::graphkernel { namespace { @@ -37,33 +38,15 @@ bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std: [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) { return false; } - auto main_attrs = main_primitive->attrs(); auto node_attrs = node_primitive->attrs(); - std::vector exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"}; for (auto &attr : exclude_attrs) { main_attrs.erase(attr); node_attrs.erase(attr); } - - if (main_attrs.size() != node_attrs.size()) { - return false; - } - - auto all = std::all_of(main_attrs.begin(), main_attrs.end(), [&node_attrs](const auto &item) -> bool { - if (item.second == nullptr) { - return false; - } - auto iter = node_attrs.find(item.first); - if (iter == node_attrs.end()) { - return false; - } - return *item.second == *iter->second; - }); - return all; + return common::IsAttrsEqual(main_attrs, node_attrs); } - return *main->inputs()[0] == *node->inputs()[0]; } } // namespace diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_cache.h b/mindspore/ccsrc/pipeline/pynative/pynative_cache.h index 1a19612b3ec..455de93e800 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_cache.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_cache.h @@ -21,6 +21,7 @@ #include #include #include "utils/hash_map.h" +#include "utils/ms_utils.h" #include "ir/anf.h" namespace mindspore::pynative { @@ -36,26 +37,10 @@ struct AbsCacheKeyHasher { struct AbsCacheKeyEqual { bool operator()(const AbsCacheKey &lk, const AbsCacheKey &rk) const { - if (lk.prim_attrs_.size() != rk.prim_attrs_.size()) { - return false; - } if (lk.prim_name_ != rk.prim_name_) { return false; } - - auto all = std::all_of(lk.prim_attrs_.begin(), lk.prim_attrs_.end(), [&rk](const auto &item) -> bool { - auto iter = rk.prim_attrs_.find(item.first); - if (iter == rk.prim_attrs_.end()) { - return false; - } - if (item.second == iter->second) { - return true; - } - MS_EXCEPTION_IF_NULL(item.second); - MS_EXCEPTION_IF_NULL(iter->second); - return *item.second == *iter->second; - }); - return all; + return common::IsAttrsEqual(lk.prim_attrs_, rk.prim_attrs_); } }; diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index 6f95149c0ce..7d4d56a28fe 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -745,30 +745,10 @@ bool AbstractClass::operator==(const AbstractClass &other) const { if (!(tag_ == other.tag_)) { return false; } - if (attributes_.size() != other.attributes_.size()) { + if (!common::IsAttrsEqual(attributes_, other.attributes_)) { return false; } - for (size_t i = 0; i < attributes_.size(); ++i) { - auto &attr1 = attributes_[i]; - auto &attr2 = other.attributes_[i]; - if (attr1.first != attr2.first || !IsEqual(attr1.second, attr2.second)) { - return false; - } - } - // Compare methods. - if (methods_.size() != other.methods_.size()) { - return false; - } - auto iter1 = methods_.begin(); - auto iter2 = other.methods_.begin(); - while (iter1 != methods_.end() && iter2 != other.methods_.end()) { - if (iter1->first != iter2->first || !IsEqual(iter1->second, iter2->second)) { - return false; - } - ++iter1; - ++iter2; - } - return true; + return common::IsAttrsEqual(methods_, other.methods_); } bool AbstractClass::operator==(const AbstractBase &other) const { @@ -1158,43 +1138,32 @@ ValuePtr AbstractKeywordArg::RealBuildValue() const { } std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) { - std::size_t hash_value = 0; - // Hashing all elements is costly, so only take at most 4 elements into account based on - // some experiments. - constexpr auto kMaxElementsNum = 4; - for (size_t i = 0; (i < args_spec_list.size()) && (i < kMaxElementsNum); i++) { - MS_EXCEPTION_IF_NULL(args_spec_list[i]); - hash_value = hash_combine(hash_value, args_spec_list[i]->hash()); + const size_t n_args = args_spec_list.size(); + std::size_t hash_value = n_args; + // Hashing all elements is costly, we only calculate hash from + // the first few elements base on some experiments. + constexpr size_t kMaxElementsNum = 4; + for (size_t i = 0; (i < n_args) && (i < kMaxElementsNum); ++i) { + const auto &arg = args_spec_list[i]; + MS_EXCEPTION_IF_NULL(arg); + hash_value = hash_combine(hash_value, arg->hash()); } return hash_value; } bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) { - if (lhs.size() != rhs.size()) { + const std::size_t size = lhs.size(); + if (size != rhs.size()) { return false; } - std::size_t size = lhs.size(); - for (std::size_t i = 0; i < size; i++) { - MS_EXCEPTION_IF_NULL(lhs[i]); - MS_EXCEPTION_IF_NULL(rhs[i]); - if (lhs[i] == rhs[i]) { - continue; - } - if (!(*lhs[i] == *rhs[i])) { + for (std::size_t i = 0; i < size; ++i) { + if (!IsEqual(lhs[i], rhs[i])) { return false; } } return true; } -std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const { - return AbstractBasePtrListHash(args_spec_list); -} - -bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const { - return AbstractBasePtrListDeepEqual(lhs, rhs); -} - // RowTensor TypePtr AbstractRowTensor::BuildType() const { MS_EXCEPTION_IF_NULL(element()); diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index d3593627d7c..7fbf64e3b6d 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -1321,17 +1321,6 @@ class MS_CORE_API AbstractRef final : public AbstractTensor { }; using AbstractRefPtr = std::shared_ptr; -/// \brief Struct AbstractBasePtrListHasher provides a function to compute the hash of a list of abstracts. -struct MS_CORE_API AbstractBasePtrListHasher { - std::size_t operator()(const AbstractBasePtrList &args_spec_list) const; -}; - -/// \brief Struct AbstractBasePtrListEqual provides a function to determine whether a list of abstracts is equal to -/// another. -struct MS_CORE_API AbstractBasePtrListEqual { - bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const; -}; - /// \brief Compute the hash of a list of abstracts. /// /// \param[in] args_spec_list A list of abstracts. @@ -1345,6 +1334,21 @@ MS_CORE_API std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_ /// \return A boolean. MS_CORE_API bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); +/// \brief Struct AbstractBasePtrListHasher provides a function to compute the hash of a list of abstracts. +struct AbstractBasePtrListHasher { + std::size_t operator()(const AbstractBasePtrList &args_spec_list) const { + return AbstractBasePtrListHash(args_spec_list); + } +}; + +/// \brief Struct AbstractBasePtrListEqual provides a function to determine whether a list of abstracts is equal to +/// another. +struct AbstractBasePtrListEqual { + bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const { + return AbstractBasePtrListDeepEqual(lhs, rhs); + } +}; + /// \brief Class AbstractRowTensor describes a RowTensor's abstract value. class MS_CORE_API AbstractRowTensor final : public AbstractUndetermined { public: diff --git a/mindspore/core/ir/cell.cc b/mindspore/core/ir/cell.cc index d27f10c21bb..423f93fe671 100644 --- a/mindspore/core/ir/cell.cc +++ b/mindspore/core/ir/cell.cc @@ -21,6 +21,7 @@ #include #include "abstract/abstract_value.h" +#include "utils/ms_utils.h" namespace mindspore { using mindspore::abstract::AbstractFunction; @@ -40,21 +41,7 @@ bool Cell::operator==(const Cell &other) const { if (name() != other.name()) { return false; } - if (attrs_.size() != other.attrs_.size()) { - return false; - } - auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const auto &item) { - if (item.second == nullptr) { - return false; - } - auto iter = other.attrs_.find(item.first); - if (iter == other.attrs_.end()) { - return false; - } - MS_EXCEPTION_IF_NULL(iter->second); - return *item.second == *iter->second; - }); - return all; + return common::IsAttrsEqual(attrs_, other.attrs_); } std::string Cell::GetAttrString() const { diff --git a/mindspore/core/ir/primitive.cc b/mindspore/core/ir/primitive.cc index 7bc8c8aa462..66d47110f03 100644 --- a/mindspore/core/ir/primitive.cc +++ b/mindspore/core/ir/primitive.cc @@ -18,6 +18,7 @@ #include #include "abstract/abstract_function.h" +#include "utils/ms_utils.h" namespace mindspore { static uint64_t MakeId() { @@ -73,20 +74,7 @@ bool Primitive::operator==(const Primitive &other) const { if (name() != other.name()) { return false; } - if (attrs_.size() != other.attrs_.size()) { - return false; - } - auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const auto &item) { - if (item.second == nullptr) { - return false; - } - auto iter = other.attrs_.find(item.first); - if (iter == other.attrs_.end()) { - return false; - } - return *item.second == *iter->second; - }); - return all; + return common::IsAttrsEqual(attrs_, other.attrs_); } std::string Primitive::GetAttrsText() const { diff --git a/mindspore/core/utils/ms_utils.h b/mindspore/core/utils/ms_utils.h index 8a2eb57e866..bd3ce3b7687 100644 --- a/mindspore/core/utils/ms_utils.h +++ b/mindspore/core/utils/ms_utils.h @@ -103,7 +103,7 @@ static inline bool CheckUseMPI() { } template -bool IsEqual(const std::shared_ptr &a, const std::shared_ptr &b) { +inline bool IsEqual(const std::shared_ptr &a, const std::shared_ptr &b) { if (a == b) { return true; } @@ -112,6 +112,29 @@ bool IsEqual(const std::shared_ptr &a, const std::shared_ptr &b) { } return *a == *b; } + +template +inline bool IsAttrsEqual(const T &a, const T &b) { + if (&a == &b) { + return true; + } + if (a.size() != b.size()) { + return false; + } + auto iter1 = a.begin(); + auto iter2 = b.begin(); + while (iter1 != a.end()) { + if (iter1->first != iter2->first) { + return false; + } + if (!IsEqual(iter1->second, iter2->second)) { + return false; + } + ++iter1; + ++iter2; + } + return true; +} } // namespace common } // namespace mindspore