Fix float max value compare

This commit is contained in:
fary86 2020-07-23 00:59:02 +08:00
parent c24122bb55
commit e470fbf2bc
2 changed files with 13 additions and 2 deletions

View File

@ -115,6 +115,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod
std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this(); std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
while (fg != nullptr && fg != specializer->func_graph_) { while (fg != nullptr && fg != specializer->func_graph_) {
specializer = specializer->parent_; specializer = specializer->parent_;
MS_EXCEPTION_IF_NULL(specializer);
} }
// If had replicated, just return that. // If had replicated, just return that.
auto iter = specializer->repl_node_->find(node); auto iter = specializer->repl_node_->find(node);

View File

@ -130,7 +130,12 @@ bool FP32Imm::operator==(const Value &other) const {
return false; return false;
} }
} }
bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } bool FP32Imm::operator==(const FP32Imm &other) const {
if (std::isinf(v_) && std::isinf(other.v_)) {
return true;
}
return fabs(v_ - other.v_) < FLT_EPSILON;
}
bool FP64Imm::operator==(const Value &other) const { bool FP64Imm::operator==(const Value &other) const {
if (other.isa<FP64Imm>()) { if (other.isa<FP64Imm>()) {
auto other_ = static_cast<const FP64Imm &>(other); auto other_ = static_cast<const FP64Imm &>(other);
@ -179,7 +184,12 @@ std::string ValueSequeue::DumpText() const {
return oss.str(); return oss.str();
} }
bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } bool FP64Imm::operator==(const FP64Imm &other) const {
if (std::isinf(v_) && std::isinf(other.v_)) {
return true;
}
return fabs(v_ - other.v_) < DBL_EPSILON;
}
bool StringImm::operator==(const Value &other) const { bool StringImm::operator==(const Value &other) const {
if (other.isa<StringImm>()) { if (other.isa<StringImm>()) {
auto other_ = static_cast<const StringImm &>(other); auto other_ = static_cast<const StringImm &>(other);