Refactor attributes compare & AbstractBasePtrListHash/Equal

This commit is contained in:
He Wei 2021-12-16 10:19:16 +08:00
parent 2510e9f3d3
commit 883543d9dc
7 changed files with 62 additions and 123 deletions

View File

@ -23,6 +23,7 @@
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_utils.h"
namespace mindspore::graphkernel {
namespace {
@ -37,33 +38,15 @@ bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std:
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
return false;
}
auto main_attrs = main_primitive->attrs();
auto node_attrs = node_primitive->attrs();
std::vector<std::string> exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"};
for (auto &attr : exclude_attrs) {
main_attrs.erase(attr);
node_attrs.erase(attr);
}
if (main_attrs.size() != node_attrs.size()) {
return false;
}
auto all = std::all_of(main_attrs.begin(), main_attrs.end(), [&node_attrs](const auto &item) -> bool {
if (item.second == nullptr) {
return false;
}
auto iter = node_attrs.find(item.first);
if (iter == node_attrs.end()) {
return false;
}
return *item.second == *iter->second;
});
return all;
return common::IsAttrsEqual(main_attrs, node_attrs);
}
return *main->inputs()[0] == *node->inputs()[0];
}
} // namespace

View File

@ -21,6 +21,7 @@
#include <memory>
#include <unordered_map>
#include "utils/hash_map.h"
#include "utils/ms_utils.h"
#include "ir/anf.h"
namespace mindspore::pynative {
@ -36,26 +37,10 @@ struct AbsCacheKeyHasher {
struct AbsCacheKeyEqual {
bool operator()(const AbsCacheKey &lk, const AbsCacheKey &rk) const {
if (lk.prim_attrs_.size() != rk.prim_attrs_.size()) {
return false;
}
if (lk.prim_name_ != rk.prim_name_) {
return false;
}
auto all = std::all_of(lk.prim_attrs_.begin(), lk.prim_attrs_.end(), [&rk](const auto &item) -> bool {
auto iter = rk.prim_attrs_.find(item.first);
if (iter == rk.prim_attrs_.end()) {
return false;
}
if (item.second == iter->second) {
return true;
}
MS_EXCEPTION_IF_NULL(item.second);
MS_EXCEPTION_IF_NULL(iter->second);
return *item.second == *iter->second;
});
return all;
return common::IsAttrsEqual(lk.prim_attrs_, rk.prim_attrs_);
}
};

View File

@ -745,30 +745,10 @@ bool AbstractClass::operator==(const AbstractClass &other) const {
if (!(tag_ == other.tag_)) {
return false;
}
if (attributes_.size() != other.attributes_.size()) {
if (!common::IsAttrsEqual(attributes_, other.attributes_)) {
return false;
}
for (size_t i = 0; i < attributes_.size(); ++i) {
auto &attr1 = attributes_[i];
auto &attr2 = other.attributes_[i];
if (attr1.first != attr2.first || !IsEqual(attr1.second, attr2.second)) {
return false;
}
}
// Compare methods.
if (methods_.size() != other.methods_.size()) {
return false;
}
auto iter1 = methods_.begin();
auto iter2 = other.methods_.begin();
while (iter1 != methods_.end() && iter2 != other.methods_.end()) {
if (iter1->first != iter2->first || !IsEqual(iter1->second, iter2->second)) {
return false;
}
++iter1;
++iter2;
}
return true;
return common::IsAttrsEqual(methods_, other.methods_);
}
bool AbstractClass::operator==(const AbstractBase &other) const {
@ -1158,43 +1138,32 @@ ValuePtr AbstractKeywordArg::RealBuildValue() const {
}
std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) {
std::size_t hash_value = 0;
// Hashing all elements is costly, so only take at most 4 elements into account based on
// some experiments.
constexpr auto kMaxElementsNum = 4;
for (size_t i = 0; (i < args_spec_list.size()) && (i < kMaxElementsNum); i++) {
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
hash_value = hash_combine(hash_value, args_spec_list[i]->hash());
const size_t n_args = args_spec_list.size();
std::size_t hash_value = n_args;
// Hashing all elements is costly, we only calculate hash from
// the first few elements base on some experiments.
constexpr size_t kMaxElementsNum = 4;
for (size_t i = 0; (i < n_args) && (i < kMaxElementsNum); ++i) {
const auto &arg = args_spec_list[i];
MS_EXCEPTION_IF_NULL(arg);
hash_value = hash_combine(hash_value, arg->hash());
}
return hash_value;
}
bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
if (lhs.size() != rhs.size()) {
const std::size_t size = lhs.size();
if (size != rhs.size()) {
return false;
}
std::size_t size = lhs.size();
for (std::size_t i = 0; i < size; i++) {
MS_EXCEPTION_IF_NULL(lhs[i]);
MS_EXCEPTION_IF_NULL(rhs[i]);
if (lhs[i] == rhs[i]) {
continue;
}
if (!(*lhs[i] == *rhs[i])) {
for (std::size_t i = 0; i < size; ++i) {
if (!IsEqual(lhs[i], rhs[i])) {
return false;
}
}
return true;
}
std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const {
return AbstractBasePtrListHash(args_spec_list);
}
bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
return AbstractBasePtrListDeepEqual(lhs, rhs);
}
// RowTensor
TypePtr AbstractRowTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(element());

View File

@ -1321,17 +1321,6 @@ class MS_CORE_API AbstractRef final : public AbstractTensor {
};
using AbstractRefPtr = std::shared_ptr<AbstractRef>;
/// \brief Struct AbstractBasePtrListHasher provides a function to compute the hash of a list of abstracts.
struct MS_CORE_API AbstractBasePtrListHasher {
std::size_t operator()(const AbstractBasePtrList &args_spec_list) const;
};
/// \brief Struct AbstractBasePtrListEqual provides a function to determine whether a list of abstracts is equal to
/// another.
struct MS_CORE_API AbstractBasePtrListEqual {
bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const;
};
/// \brief Compute the hash of a list of abstracts.
///
/// \param[in] args_spec_list A list of abstracts.
@ -1345,6 +1334,21 @@ MS_CORE_API std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_
/// \return A boolean.
MS_CORE_API bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
/// \brief Struct AbstractBasePtrListHasher provides a function to compute the hash of a list of abstracts.
struct AbstractBasePtrListHasher {
std::size_t operator()(const AbstractBasePtrList &args_spec_list) const {
return AbstractBasePtrListHash(args_spec_list);
}
};
/// \brief Struct AbstractBasePtrListEqual provides a function to determine whether a list of abstracts is equal to
/// another.
struct AbstractBasePtrListEqual {
bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
return AbstractBasePtrListDeepEqual(lhs, rhs);
}
};
/// \brief Class AbstractRowTensor describes a RowTensor's abstract value.
class MS_CORE_API AbstractRowTensor final : public AbstractUndetermined {
public:

View File

@ -21,6 +21,7 @@
#include <algorithm>
#include "abstract/abstract_value.h"
#include "utils/ms_utils.h"
namespace mindspore {
using mindspore::abstract::AbstractFunction;
@ -40,21 +41,7 @@ bool Cell::operator==(const Cell &other) const {
if (name() != other.name()) {
return false;
}
if (attrs_.size() != other.attrs_.size()) {
return false;
}
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const auto &item) {
if (item.second == nullptr) {
return false;
}
auto iter = other.attrs_.find(item.first);
if (iter == other.attrs_.end()) {
return false;
}
MS_EXCEPTION_IF_NULL(iter->second);
return *item.second == *iter->second;
});
return all;
return common::IsAttrsEqual(attrs_, other.attrs_);
}
std::string Cell::GetAttrString() const {

View File

@ -18,6 +18,7 @@
#include <utility>
#include "abstract/abstract_function.h"
#include "utils/ms_utils.h"
namespace mindspore {
static uint64_t MakeId() {
@ -73,20 +74,7 @@ bool Primitive::operator==(const Primitive &other) const {
if (name() != other.name()) {
return false;
}
if (attrs_.size() != other.attrs_.size()) {
return false;
}
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const auto &item) {
if (item.second == nullptr) {
return false;
}
auto iter = other.attrs_.find(item.first);
if (iter == other.attrs_.end()) {
return false;
}
return *item.second == *iter->second;
});
return all;
return common::IsAttrsEqual(attrs_, other.attrs_);
}
std::string Primitive::GetAttrsText() const {

View File

@ -103,7 +103,7 @@ static inline bool CheckUseMPI() {
}
template <typename T>
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
inline bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
if (a == b) {
return true;
}
@ -112,6 +112,29 @@ bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
}
return *a == *b;
}
template <typename T>
inline bool IsAttrsEqual(const T &a, const T &b) {
if (&a == &b) {
return true;
}
if (a.size() != b.size()) {
return false;
}
auto iter1 = a.begin();
auto iter2 = b.begin();
while (iter1 != a.end()) {
if (iter1->first != iter2->first) {
return false;
}
if (!IsEqual(iter1->second, iter2->second)) {
return false;
}
++iter1;
++iter2;
}
return true;
}
} // namespace common
} // namespace mindspore