Optimize operator== for Abstract classes

This commit is contained in:
He Wei 2021-12-14 09:47:38 +08:00
parent 3655de93db
commit f12f99edfc
1 changed files with 120 additions and 207 deletions

View File

@ -23,12 +23,15 @@
#include "utils/hash_map.h"
#include "utils/symbolic.h"
#include "utils/ms_utils.h"
#include "abstract/utils.h"
#include "utils/ms_context.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace abstract {
using mindspore::common::IsEqual;
AnfNodePtr GetTraceNode(const AbstractBasePtr &abs) {
AnfNodePtr node = nullptr;
if (mindspore::abstract::AbstractBase::trace_node_provider_ != nullptr) {
@ -87,44 +90,14 @@ std::string ExtractLoggingInfo(const std::string &info) {
}
bool AbstractBase::operator==(const AbstractBase &other) const {
if (tid() != other.tid()) {
return false;
}
auto type = BuildType();
auto other_type = BuildType();
MS_EXCEPTION_IF_NULL(other_type);
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() == kObjectTypeUndeterminedType && other_type->type_id() == kObjectTypeUndeterminedType) {
if (this == &other) {
// Same object.
return true;
}
if (value_ == nullptr || other.value_ == nullptr) {
MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
<< this->ToString() << ", other: " << other.ToString();
}
bool value_equal = false;
if (value_ == other.value_) {
value_equal = true;
} else if (*value_ == *other.value_) {
value_equal = true;
}
bool type_equal = false;
MS_EXCEPTION_IF_NULL(type_);
MS_EXCEPTION_IF_NULL(other.type_);
if (type_ == other.type_) {
type_equal = true;
} else if (*type_ == *other.type_) {
type_equal = true;
}
bool shape_equal = false;
MS_EXCEPTION_IF_NULL(shape_);
MS_EXCEPTION_IF_NULL(other.shape_);
if (shape_ == other.shape_) {
shape_equal = true;
} else if (*shape_ == *other.shape_) {
shape_equal = true;
}
return value_equal && type_equal && shape_equal;
return tid() == other.tid() && // c++ type equal and
IsEqual(type_, other.type_) && // type equal and
IsEqual(shape_, other.shape_) && // shape equal and
IsEqual(value_, other.value_); // value equal.
}
ValuePtr AbstractBase::BuildValue() const {
@ -199,23 +172,10 @@ AbstractBasePtr AbstractType::Clone() const {
}
bool AbstractType::operator==(const AbstractBase &other) const {
if (tid() != other.tid()) {
return false;
if (this == &other) {
return true;
}
// Have to compare TypePtr with value;
ValuePtr value_self = GetValueTrack();
ValuePtr value_other = other.GetValueTrack();
if (value_self == nullptr || value_other == nullptr) {
MS_LOG(EXCEPTION) << "AbstractType value should not be nullptr. this: " << this->ToString()
<< ", other: " << other.ToString();
}
if (!value_self->isa<Type>() || !value_other->isa<Type>()) {
return false;
}
TypePtr type_self = value_self->cast<TypePtr>();
TypePtr type_other = value_other->cast<TypePtr>();
bool value_equal = *type_self == *type_other;
return value_equal;
return tid() == other.tid() && IsEqual(dyn_cast<Type>(GetValueTrack()), dyn_cast<Type>(other.GetValueTrack()));
}
std::string AbstractType::ToString() const {
@ -255,12 +215,13 @@ AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) {
}
bool AbstractFunction::operator==(const AbstractBase &other) const {
if (this == &other) {
return true;
}
if (!other.isa<AbstractFunction>()) {
return false;
}
const auto &other_func = static_cast<const AbstractFunction &>(other);
bool value_equal = (*this == other_func);
return value_equal;
return *this == static_cast<const AbstractFunction &>(other);
}
const AbstractBasePtr AbstractSequence::operator[](const std::size_t &dim) const {
@ -387,17 +348,14 @@ std::size_t AbstractSequence::hash() const {
}
bool AbstractSequence::operator==(const AbstractSequence &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (elements_.size() != other.elements_.size()) {
return false;
}
for (size_t i = 0; i < elements_.size(); i++) {
MS_EXCEPTION_IF_NULL(elements_[i]);
MS_EXCEPTION_IF_NULL(other.elements_[i]);
if (!(*(elements_[i]) == *(other.elements_[i]))) {
for (size_t i = 0; i < elements_.size(); ++i) {
if (!IsEqual(elements_[i], other.elements_[i])) {
return false;
}
}
@ -407,16 +365,13 @@ bool AbstractSequence::operator==(const AbstractSequence &other) const {
bool AbstractTuple::operator==(const AbstractTuple &other) const { return AbstractSequence::operator==(other); }
bool AbstractTuple::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (other.isa<AbstractTuple>()) {
auto other_tuple = static_cast<const AbstractTuple *>(&other);
return *this == *other_tuple;
if (!other.isa<AbstractTuple>()) {
return false;
}
return false;
return AbstractSequence::operator==(static_cast<const AbstractSequence &>(other));
}
bool AbstractTuple::ContainsAllBroadenTensors() const {
@ -433,15 +388,13 @@ bool AbstractTuple::ContainsAllBroadenTensors() const {
bool AbstractList::operator==(const AbstractList &other) const { return AbstractSequence::operator==(other); }
bool AbstractList::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (other.isa<AbstractList>()) {
auto other_list = static_cast<const AbstractList *>(&other);
return *this == *other_list;
if (!other.isa<AbstractList>()) {
return false;
}
return false;
return AbstractSequence::operator==(static_cast<const AbstractSequence &>(other));
}
TypePtr AbstractSlice::BuildType() const {
@ -455,21 +408,20 @@ TypePtr AbstractSlice::BuildType() const {
}
bool AbstractSlice::operator==(const AbstractSlice &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
return (*start_ == *other.start_ && *stop_ == *other.stop_ && *step_ == *other.step_);
return IsEqual(start_, other.start_) && IsEqual(stop_, other.stop_) && IsEqual(step_, other.step_);
}
bool AbstractSlice::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (!other.isa<AbstractSlice>()) {
return false;
}
auto other_slice = static_cast<const AbstractSlice *>(&other);
return *this == *other_slice;
return *this == static_cast<const AbstractSlice &>(other);
}
AbstractBasePtr AbstractSlice::Clone() const {
@ -597,51 +549,37 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
}
bool AbstractTensor::equal_to(const AbstractTensor &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
// Check value. for AbstractTensor, both value should be AnyValue.
auto v1 = GetValueTrack();
auto v2 = other.GetValueTrack();
if (v1 == nullptr || v2 == nullptr) {
MS_LOG(EXCEPTION) << "The value of AbstractTensor is nullptr";
}
bool is_value_equal = (v1 == v2);
if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
is_value_equal = true;
}
MS_EXCEPTION_IF_NULL(element_);
MS_EXCEPTION_IF_NULL(other.element_);
bool is_elements_equal = (*element_ == *other.element_);
bool is_shape_equal = (*shape() == *other.shape());
if (!is_value_equal || !is_shape_equal || !is_elements_equal) {
if (v1 != v2 && (v1 == nullptr || !v1->isa<AnyValue>() || v2 == nullptr || !v2->isa<AnyValue>())) {
return false;
}
// check min and max value
auto min_value = get_min_value();
auto max_value = get_max_value();
auto other_min_value = other.get_min_value();
auto other_max_value = other.get_max_value();
if (min_value != nullptr && max_value != nullptr && other_min_value != nullptr && other_max_value != nullptr) {
return (*min_value == *other_min_value) && (*max_value == *other_max_value);
// Check element type.
if (!IsEqual(element_, other.element_)) {
return false;
}
return true;
// Check shape.
if (!IsEqual(shape(), other.shape())) {
return false;
}
// Check min and max values.
return IsEqual(get_min_value(), other.get_min_value()) && IsEqual(get_max_value(), other.get_max_value());
}
bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
bool AbstractTensor::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (other.tid() == tid()) {
auto other_tensor = static_cast<const AbstractTensor *>(&other);
return *this == *other_tensor;
} else {
if (tid() != other.tid()) {
return false;
}
return equal_to(static_cast<const AbstractTensor &>(other));
}
AbstractBasePtr AbstractTensor::Clone() const {
@ -704,14 +642,10 @@ bool AbstractDictionary::operator==(const AbstractDictionary &other) const {
if (key_values_.size() != other.key_values_.size()) {
return false;
}
for (size_t index = 0; index < key_values_.size(); index++) {
if (key_values_[index].first != other.key_values_[index].first) {
return false;
}
MS_EXCEPTION_IF_NULL(key_values_[index].second);
MS_EXCEPTION_IF_NULL(other.key_values_[index].second);
if (!(*key_values_[index].second == *other.key_values_[index].second)) {
for (size_t index = 0; index < key_values_.size(); ++index) {
auto &kv1 = key_values_[index];
auto &kv2 = other.key_values_[index];
if (kv1.first != kv2.first || !IsEqual(kv1.second, kv2.second)) {
return false;
}
}
@ -719,14 +653,13 @@ bool AbstractDictionary::operator==(const AbstractDictionary &other) const {
}
bool AbstractDictionary::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (other.isa<AbstractDictionary>()) {
auto other_class = static_cast<const AbstractDictionary *>(&other);
return *this == *other_class;
if (!other.isa<AbstractDictionary>()) {
return false;
}
return false;
return *this == static_cast<const AbstractDictionary &>(other);
}
AbstractBasePtr AbstractDictionary::Clone() const {
@ -804,37 +737,37 @@ bool AbstractClass::operator==(const AbstractClass &other) const {
if (attributes_.size() != other.attributes_.size()) {
return false;
}
for (size_t i = 0; i < attributes_.size(); i++) {
MS_EXCEPTION_IF_NULL(attributes_[i].second);
MS_EXCEPTION_IF_NULL(other.attributes_[i].second);
if (!(*attributes_[i].second == *other.attributes_[i].second)) {
MS_LOG(DEBUG) << "attr " << attributes_[i].first << " not equal, arg1:" << attributes_[i].second->ToString()
<< " arg2:" << other.attributes_[i].second->ToString();
for (size_t i = 0; i < attributes_.size(); ++i) {
auto &attr1 = attributes_[i];
auto &attr2 = other.attributes_[i];
if (attr1.first != attr2.first || !IsEqual(attr1.second, attr2.second)) {
return false;
}
}
// method compare;
// Compare methods.
if (methods_.size() != other.methods_.size()) {
return false;
}
for (const auto &iter : methods_) {
auto iter_other = other.methods_.find(iter.first);
if (iter_other == other.methods_.end()) {
return false;
}
if (!(*iter.second == *iter_other->second)) {
auto iter1 = methods_.begin();
auto iter2 = other.methods_.begin();
while (iter1 != methods_.end() && iter2 != other.methods_.end()) {
if (iter1->first != iter2->first || !IsEqual(iter1->second, iter2->second)) {
return false;
}
++iter1;
++iter2;
}
return true;
}
bool AbstractClass::operator==(const AbstractBase &other) const {
if (other.isa<AbstractClass>()) {
auto other_class = static_cast<const AbstractClass *>(&other);
return *this == *other_class;
if (this == &other) {
return true;
}
return false;
if (!other.isa<AbstractClass>()) {
return false;
}
return *this == static_cast<const AbstractClass &>(other);
}
AbstractBasePtr AbstractClass::GetAttribute(const std::string &name) {
@ -949,18 +882,16 @@ AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
return std::make_shared<AbstractJTagged>(joined_elem);
}
bool AbstractJTagged::operator==(const AbstractJTagged &other) const {
MS_EXCEPTION_IF_NULL(element_);
MS_EXCEPTION_IF_NULL(other.element_);
return (*element_ == *other.element_);
}
bool AbstractJTagged::operator==(const AbstractJTagged &other) const { return IsEqual(element_, other.element_); }
bool AbstractJTagged::operator==(const AbstractBase &other) const {
if (other.isa<AbstractJTagged>()) {
auto other_jtagged = static_cast<const AbstractJTagged *>(&other);
return *this == *other_jtagged;
if (this == &other) {
return true;
}
return false;
if (!other.isa<AbstractJTagged>()) {
return false;
}
return *this == static_cast<const AbstractJTagged &>(other);
}
std::string AbstractJTagged::ToString() const {
@ -987,15 +918,20 @@ TypePtr AbstractRef::BuildType() const {
}
bool AbstractRef::operator==(const AbstractRef &other) const {
return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_);
if (this == &other) {
return true;
}
return IsEqual(ref_key_, other.ref_key_) && AbstractTensor::equal_to(other);
}
bool AbstractRef::operator==(const AbstractBase &other) const {
if (other.isa<AbstractRef>()) {
auto other_conf = static_cast<const AbstractRef *>(&other);
return *this == *other_conf;
if (this == &other) {
return true;
}
return false;
if (!other.isa<AbstractRef>()) {
return false;
}
return *this == static_cast<const AbstractRef &>(other);
}
AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
@ -1057,11 +993,10 @@ AbstractBasePtr AbstractRef::PartialBroaden() const { return Clone(); }
bool AbstractNone::operator==(const AbstractNone &) const { return true; }
bool AbstractNone::operator==(const AbstractBase &other) const {
if (other.isa<AbstractNone>()) {
auto other_none = static_cast<const AbstractNone *>(&other);
return *this == *other_none;
if (this == &other) {
return true;
}
return false;
return other.isa<AbstractNone>();
}
std::string AbstractNone::ToString() const {
@ -1073,31 +1008,28 @@ std::string AbstractNone::ToString() const {
ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
ValuePtr value_self = GetValueTrack();
ValuePtr value_other = other.GetValueTrack();
if (value_self != nullptr && value_other != nullptr) {
if (value_self->isa<AnyValue>() && value_other->isa<AnyValue>()) {
return true;
}
if (!value_self->isa<RefKey>() || !value_other->isa<RefKey>()) {
return false;
}
RefKeyPtr type_self = value_self->cast<RefKeyPtr>();
RefKeyPtr type_other = value_other->cast<RefKeyPtr>();
return *type_self == *type_other;
} else if (value_self != nullptr || value_other != nullptr) {
ValuePtr v1 = GetValueTrack();
ValuePtr v2 = other.GetValueTrack();
if (v1 == v2) {
return true;
}
if (v1 == nullptr || v2 == nullptr) {
return false;
}
return true;
if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
return true;
}
return IsEqual(dyn_cast<RefKey>(v1), dyn_cast<RefKey>(v2));
}
bool AbstractRefKey::operator==(const AbstractBase &other) const {
if (other.isa<AbstractRefKey>()) {
auto other_confkey = static_cast<const AbstractRefKey *>(&other);
return *this == *other_confkey;
} else {
if (this == &other) {
return true;
}
if (!other.isa<AbstractRefKey>()) {
return false;
}
return *this == static_cast<const AbstractRefKey &>(other);
}
std::string AbstractRefKey::ToString() const {
@ -1113,15 +1045,10 @@ std::string AbstractRefKey::ToString() const {
bool AbstractNull::operator==(const AbstractNull &) const { return true; }
bool AbstractNull::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (other.isa<AbstractNull>()) {
auto other_none = static_cast<const AbstractNull *>(&other);
return *this == *other_none;
} else {
return false;
}
return other.isa<AbstractNull>();
}
std::string AbstractNull::ToString() const {
@ -1133,15 +1060,10 @@ std::string AbstractNull::ToString() const {
bool AbstractTimeOut::operator==(const AbstractTimeOut &) const { return true; }
bool AbstractTimeOut::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (other.isa<AbstractTimeOut>()) {
auto other_none = static_cast<const AbstractTimeOut *>(&other);
return *this == *other_none;
} else {
return false;
}
return other.isa<AbstractTimeOut>();
}
std::string AbstractTimeOut::ToString() const {
@ -1154,15 +1076,10 @@ std::string AbstractTimeOut::ToString() const {
bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
bool AbstractEllipsis::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (other.isa<AbstractEllipsis>()) {
auto other_none = static_cast<const AbstractEllipsis *>(&other);
return *this == *other_none;
} else {
return false;
}
return other.isa<AbstractEllipsis>();
}
std::string AbstractEllipsis::ToString() const {
@ -1203,24 +1120,20 @@ std::string AbstractKeywordArg::ToString() const {
}
bool AbstractKeywordArg::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
if (other.isa<AbstractKeywordArg>()) {
auto other_tuple = static_cast<const AbstractKeywordArg *>(&other);
return *this == *other_tuple;
if (!other.isa<AbstractKeywordArg>()) {
return false;
}
return false;
return *this == static_cast<const AbstractKeywordArg &>(other);
}
bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
MS_EXCEPTION_IF_NULL(arg_value_);
MS_EXCEPTION_IF_NULL(other.arg_value_);
return other.arg_name_ == arg_name_ && *other.arg_value_ == *arg_value_;
return other.arg_name_ == arg_name_ && IsEqual(other.arg_value_, arg_value_);
}
ValuePtr AbstractKeywordArg::RealBuildValue() const {
@ -1579,7 +1492,7 @@ AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) {
bool AbstractUMonad::operator==(const AbstractUMonad &) const { return true; }
bool AbstractUMonad::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
return other.isa<AbstractUMonad>();
@ -1600,7 +1513,7 @@ AbstractBasePtr AbstractIOMonad::Join(const AbstractBasePtr &other) {
bool AbstractIOMonad::operator==(const AbstractIOMonad &) const { return true; }
bool AbstractIOMonad::operator==(const AbstractBase &other) const {
if (&other == this) {
if (this == &other) {
return true;
}
return other.isa<AbstractIOMonad>();