forked from mindspore-Ecosystem/mindspore
Refactor attributes compare & AbstractBasePtrListHash/Equal
This commit is contained in:
parent
2510e9f3d3
commit
883543d9dc
|
@ -23,6 +23,7 @@
|
|||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
namespace {
|
||||
|
@ -37,33 +38,15 @@ bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std:
|
|||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto main_attrs = main_primitive->attrs();
|
||||
auto node_attrs = node_primitive->attrs();
|
||||
|
||||
std::vector<std::string> exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"};
|
||||
for (auto &attr : exclude_attrs) {
|
||||
main_attrs.erase(attr);
|
||||
node_attrs.erase(attr);
|
||||
}
|
||||
|
||||
if (main_attrs.size() != node_attrs.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto all = std::all_of(main_attrs.begin(), main_attrs.end(), [&node_attrs](const auto &item) -> bool {
|
||||
if (item.second == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto iter = node_attrs.find(item.first);
|
||||
if (iter == node_attrs.end()) {
|
||||
return false;
|
||||
}
|
||||
return *item.second == *iter->second;
|
||||
});
|
||||
return all;
|
||||
return common::IsAttrsEqual(main_attrs, node_attrs);
|
||||
}
|
||||
|
||||
return *main->inputs()[0] == *node->inputs()[0];
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore::pynative {
|
||||
|
@ -36,26 +37,10 @@ struct AbsCacheKeyHasher {
|
|||
|
||||
struct AbsCacheKeyEqual {
|
||||
bool operator()(const AbsCacheKey &lk, const AbsCacheKey &rk) const {
|
||||
if (lk.prim_attrs_.size() != rk.prim_attrs_.size()) {
|
||||
return false;
|
||||
}
|
||||
if (lk.prim_name_ != rk.prim_name_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto all = std::all_of(lk.prim_attrs_.begin(), lk.prim_attrs_.end(), [&rk](const auto &item) -> bool {
|
||||
auto iter = rk.prim_attrs_.find(item.first);
|
||||
if (iter == rk.prim_attrs_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (item.second == iter->second) {
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return *item.second == *iter->second;
|
||||
});
|
||||
return all;
|
||||
return common::IsAttrsEqual(lk.prim_attrs_, rk.prim_attrs_);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -745,30 +745,10 @@ bool AbstractClass::operator==(const AbstractClass &other) const {
|
|||
if (!(tag_ == other.tag_)) {
|
||||
return false;
|
||||
}
|
||||
if (attributes_.size() != other.attributes_.size()) {
|
||||
if (!common::IsAttrsEqual(attributes_, other.attributes_)) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < attributes_.size(); ++i) {
|
||||
auto &attr1 = attributes_[i];
|
||||
auto &attr2 = other.attributes_[i];
|
||||
if (attr1.first != attr2.first || !IsEqual(attr1.second, attr2.second)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Compare methods.
|
||||
if (methods_.size() != other.methods_.size()) {
|
||||
return false;
|
||||
}
|
||||
auto iter1 = methods_.begin();
|
||||
auto iter2 = other.methods_.begin();
|
||||
while (iter1 != methods_.end() && iter2 != other.methods_.end()) {
|
||||
if (iter1->first != iter2->first || !IsEqual(iter1->second, iter2->second)) {
|
||||
return false;
|
||||
}
|
||||
++iter1;
|
||||
++iter2;
|
||||
}
|
||||
return true;
|
||||
return common::IsAttrsEqual(methods_, other.methods_);
|
||||
}
|
||||
|
||||
bool AbstractClass::operator==(const AbstractBase &other) const {
|
||||
|
@ -1158,43 +1138,32 @@ ValuePtr AbstractKeywordArg::RealBuildValue() const {
|
|||
}
|
||||
|
||||
std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) {
|
||||
std::size_t hash_value = 0;
|
||||
// Hashing all elements is costly, so only take at most 4 elements into account based on
|
||||
// some experiments.
|
||||
constexpr auto kMaxElementsNum = 4;
|
||||
for (size_t i = 0; (i < args_spec_list.size()) && (i < kMaxElementsNum); i++) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
|
||||
hash_value = hash_combine(hash_value, args_spec_list[i]->hash());
|
||||
const size_t n_args = args_spec_list.size();
|
||||
std::size_t hash_value = n_args;
|
||||
// Hashing all elements is costly, we only calculate hash from
|
||||
// the first few elements base on some experiments.
|
||||
constexpr size_t kMaxElementsNum = 4;
|
||||
for (size_t i = 0; (i < n_args) && (i < kMaxElementsNum); ++i) {
|
||||
const auto &arg = args_spec_list[i];
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
hash_value = hash_combine(hash_value, arg->hash());
|
||||
}
|
||||
return hash_value;
|
||||
}
|
||||
|
||||
bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
const std::size_t size = lhs.size();
|
||||
if (size != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
std::size_t size = lhs.size();
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
MS_EXCEPTION_IF_NULL(lhs[i]);
|
||||
MS_EXCEPTION_IF_NULL(rhs[i]);
|
||||
if (lhs[i] == rhs[i]) {
|
||||
continue;
|
||||
}
|
||||
if (!(*lhs[i] == *rhs[i])) {
|
||||
for (std::size_t i = 0; i < size; ++i) {
|
||||
if (!IsEqual(lhs[i], rhs[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const {
|
||||
return AbstractBasePtrListHash(args_spec_list);
|
||||
}
|
||||
|
||||
bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
|
||||
return AbstractBasePtrListDeepEqual(lhs, rhs);
|
||||
}
|
||||
|
||||
// RowTensor
|
||||
TypePtr AbstractRowTensor::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
|
|
|
@ -1321,17 +1321,6 @@ class MS_CORE_API AbstractRef final : public AbstractTensor {
|
|||
};
|
||||
using AbstractRefPtr = std::shared_ptr<AbstractRef>;
|
||||
|
||||
/// \brief Struct AbstractBasePtrListHasher provides a function to compute the hash of a list of abstracts.
|
||||
struct MS_CORE_API AbstractBasePtrListHasher {
|
||||
std::size_t operator()(const AbstractBasePtrList &args_spec_list) const;
|
||||
};
|
||||
|
||||
/// \brief Struct AbstractBasePtrListEqual provides a function to determine whether a list of abstracts is equal to
|
||||
/// another.
|
||||
struct MS_CORE_API AbstractBasePtrListEqual {
|
||||
bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const;
|
||||
};
|
||||
|
||||
/// \brief Compute the hash of a list of abstracts.
|
||||
///
|
||||
/// \param[in] args_spec_list A list of abstracts.
|
||||
|
@ -1345,6 +1334,21 @@ MS_CORE_API std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_
|
|||
/// \return A boolean.
|
||||
MS_CORE_API bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
|
||||
|
||||
/// \brief Struct AbstractBasePtrListHasher provides a function to compute the hash of a list of abstracts.
|
||||
struct AbstractBasePtrListHasher {
|
||||
std::size_t operator()(const AbstractBasePtrList &args_spec_list) const {
|
||||
return AbstractBasePtrListHash(args_spec_list);
|
||||
}
|
||||
};
|
||||
|
||||
/// \brief Struct AbstractBasePtrListEqual provides a function to determine whether a list of abstracts is equal to
|
||||
/// another.
|
||||
struct AbstractBasePtrListEqual {
|
||||
bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
|
||||
return AbstractBasePtrListDeepEqual(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
/// \brief Class AbstractRowTensor describes a RowTensor's abstract value.
|
||||
class MS_CORE_API AbstractRowTensor final : public AbstractUndetermined {
|
||||
public:
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <algorithm>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
|
@ -40,21 +41,7 @@ bool Cell::operator==(const Cell &other) const {
|
|||
if (name() != other.name()) {
|
||||
return false;
|
||||
}
|
||||
if (attrs_.size() != other.attrs_.size()) {
|
||||
return false;
|
||||
}
|
||||
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const auto &item) {
|
||||
if (item.second == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto iter = other.attrs_.find(item.first);
|
||||
if (iter == other.attrs_.end()) {
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return *item.second == *iter->second;
|
||||
});
|
||||
return all;
|
||||
return common::IsAttrsEqual(attrs_, other.attrs_);
|
||||
}
|
||||
|
||||
std::string Cell::GetAttrString() const {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <utility>
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
static uint64_t MakeId() {
|
||||
|
@ -73,20 +74,7 @@ bool Primitive::operator==(const Primitive &other) const {
|
|||
if (name() != other.name()) {
|
||||
return false;
|
||||
}
|
||||
if (attrs_.size() != other.attrs_.size()) {
|
||||
return false;
|
||||
}
|
||||
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const auto &item) {
|
||||
if (item.second == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto iter = other.attrs_.find(item.first);
|
||||
if (iter == other.attrs_.end()) {
|
||||
return false;
|
||||
}
|
||||
return *item.second == *iter->second;
|
||||
});
|
||||
return all;
|
||||
return common::IsAttrsEqual(attrs_, other.attrs_);
|
||||
}
|
||||
|
||||
std::string Primitive::GetAttrsText() const {
|
||||
|
|
|
@ -103,7 +103,7 @@ static inline bool CheckUseMPI() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
|
||||
inline bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
|
||||
if (a == b) {
|
||||
return true;
|
||||
}
|
||||
|
@ -112,6 +112,29 @@ bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
|
|||
}
|
||||
return *a == *b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool IsAttrsEqual(const T &a, const T &b) {
|
||||
if (&a == &b) {
|
||||
return true;
|
||||
}
|
||||
if (a.size() != b.size()) {
|
||||
return false;
|
||||
}
|
||||
auto iter1 = a.begin();
|
||||
auto iter2 = b.begin();
|
||||
while (iter1 != a.end()) {
|
||||
if (iter1->first != iter2->first) {
|
||||
return false;
|
||||
}
|
||||
if (!IsEqual(iter1->second, iter2->second)) {
|
||||
return false;
|
||||
}
|
||||
++iter1;
|
||||
++iter2;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue