diff --git a/mindspore/ccsrc/frontend/operator/cc_implementations.cc b/mindspore/ccsrc/frontend/operator/cc_implementations.cc index 3f91632c799..e8e32afb694 100644 --- a/mindspore/ccsrc/frontend/operator/cc_implementations.cc +++ b/mindspore/ccsrc/frontend/operator/cc_implementations.cc @@ -140,7 +140,7 @@ T InnerScalarPow(T x, U y) { template bool InnerScalarEq(T x, U y) { if (std::isinf(static_cast(x)) && std::isinf(static_cast(y))) { - return true; + return (x > 0 && y > 0) || (x < 0 && y < 0); } double error = static_cast(x) - static_cast(y); error = fabs(error); diff --git a/tests/ut/python/graph_syntax/operators/test_operator.py b/tests/ut/python/graph_syntax/operators/test_operator.py index 9353cdfee2c..2b2447d3ec8 100644 --- a/tests/ut/python/graph_syntax/operators/test_operator.py +++ b/tests/ut/python/graph_syntax/operators/test_operator.py @@ -212,8 +212,10 @@ def test_equal_inf(): Expectation: success """ @ms.jit - def func(x): - return x == float("inf") + def func(x, y): + return x == float("inf"), y == float("-inf"), x == y x = float("inf") - assert func(x) + y = float("-inf") + out = func(x, y) + assert out == (True, True, False)