add min/max value in abstract check equal function

This commit is contained in:
王南 2021-11-19 15:40:28 +08:00
parent d4e77ed507
commit ca39b11377
1 changed files with 14 additions and 1 deletions

View File

@ -612,7 +612,20 @@ bool AbstractTensor::equal_to(const AbstractTensor &other) const {
}
MS_EXCEPTION_IF_NULL(element_);
MS_EXCEPTION_IF_NULL(other.element_);
return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
bool is_elements_equal = (*element_ == *other.element_);
bool is_shape_equal = (*shape() == *other.shape());
if (!is_value_equal || !is_shape_equal || !is_elements_equal) {
return false;
}
// check min and max value
auto min_value = get_min_value();
auto max_value = get_max_value();
auto other_min_value = other.get_min_value();
auto other_max_value = other.get_max_value();
if (min_value != nullptr && max_value != nullptr && other_min_value != nullptr && other_max_value != nullptr) {
return (*min_value == *other_min_value) && (*max_value == *other_max_value);
}
return true;
}
bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }