fix float compare bug

This commit is contained in:
hangangqiang 2022-07-18 17:34:32 +08:00
parent 8566f04d51
commit 5aeb8e2818
4 changed files with 13 additions and 3 deletions

View File

@ -21,6 +21,8 @@
#include <cstdint>
#include <vector>
#include <set>
#include <limits>
#include <cmath>
#include <string>
#include <utility>
#include "src/common/log_adapter.h"
@ -241,6 +243,10 @@ inline size_t DataTypeSize(TypeId type) {
}
}
inline bool FloatCompare(const float &a, const float &b = 0.0f) {
return std::fabs(a - b) <= std::numeric_limits<float>::epsilon();
}
} // namespace lite
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#include <memory>
#include "ops/fusion/activation.h"
#include "nnacl/op_base.h"
#include "src/common/utils.h"
namespace mindspore {
namespace lite {
@ -30,7 +31,7 @@ PrimitiveCPtr CaffeReluParser::Parse(const caffe::LayerParameter &proto, const c
if (proto.has_relu_param() && proto.relu_param().has_negative_slope()) {
float negative_slope = proto.relu_param().negative_slope();
if (negative_slope != 0) {
if (!FloatCompare(negative_slope)) {
prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
prim->set_alpha(negative_slope);
}

View File

@ -18,6 +18,7 @@
#include <memory>
#include "ops/fusion/mat_mul_fusion.h"
#include "nnacl/op_base.h"
#include "src/common/utils.h"
namespace mindspore {
namespace lite {
@ -38,7 +39,8 @@ PrimitiveCPtr OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const
beta = onnx_node_attr.f();
}
}
if (alpha != 1 || (beta != 1 && !(onnx_node.input().size() == 2 && beta == 0))) { // 2 : input num is A and B
if (!FloatCompare(alpha, 1.0f) || (!FloatCompare(beta, 1.0f) && !(onnx_node.input().size() == 2 &&
!FloatCompare(beta)))) { // 2: input num is A and B
MS_LOG(ERROR) << "not support alpha * A * B + beta * C";
return nullptr;
}

View File

@ -26,6 +26,7 @@
#include "nnacl/op_base.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "src/common/utils.h"
namespace mindspore {
namespace opt {
@ -174,7 +175,7 @@ bool IsPrimitiveProper(const CNodePtr &pad_cnode) {
}
pad_value = *static_cast<float *>(data_info.data_ptr_);
}
if (pad_value != 0) {
if (!mindspore::lite::FloatCompare(pad_value)) {
return false;
}