!48757 isclose计算过慢

Merge pull request !48757 from YingtongHu/master
This commit is contained in:
i-robot 2023-02-13 08:12:41 +00:00 committed by Gitee
commit 6a0f1e54f8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 5 additions and 28 deletions

View File

@ -19,10 +19,10 @@ from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.common import Tensor
from mindspore.numpy.math_ops import _apply_tensor_op, absolute
from mindspore.numpy.array_creations import zeros, ones, empty, asarray
from mindspore.numpy.math_ops import _apply_tensor_op
from mindspore.numpy.array_creations import zeros, ones, asarray
from mindspore.numpy.utils import _check_input_tensor, _to_tensor, _isnan
from mindspore.numpy.utils_const import _raise_type_error, _is_shape_empty, _infer_out_shape, _check_same_type, \
from mindspore.numpy.utils_const import _raise_type_error, _check_same_type, \
_check_axis_type, _canonicalize_axis, _can_broadcast, _isscalar
@ -489,31 +489,8 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
>>> print(np.isclose(a, b, equal_nan=True))
[ True True False False True True]
"""
a, b = _to_tensor(a, b)
if not isinstance(rtol, (int, float, bool)) or not isinstance(atol, (int, float, bool)):
_raise_type_error("rtol and atol are expected to be numbers.")
if not isinstance(equal_nan, bool):
_raise_type_error("equal_nan is expected to be bool.")
if _is_shape_empty(a.shape) or _is_shape_empty(b.shape):
return empty(_infer_out_shape(a.shape, b.shape), dtype=mstype.bool_)
rtol = _to_tensor(rtol).astype("float32")
atol = _to_tensor(atol).astype("float32")
res = absolute(a - b) <= (atol + rtol * absolute(b))
# infs are treated as equal
a_posinf = isposinf(a)
b_posinf = isposinf(b)
a_neginf = isneginf(a)
b_neginf = isneginf(b)
same_inf = F.logical_or(F.logical_and(a_posinf, b_posinf), F.logical_and(a_neginf, b_neginf))
diff_inf = F.logical_or(F.logical_and(a_posinf, b_neginf), F.logical_and(a_neginf, b_posinf))
res = F.logical_and(F.logical_or(res, same_inf), F.logical_not(diff_inf))
both_nan = F.logical_and(_isnan(a), _isnan(b))
if equal_nan:
res = F.logical_or(both_nan, res)
else:
res = F.logical_and(F.logical_not(both_nan), res)
return res
is_close = P.IsClose(rtol=rtol, atol=atol, equal_nan=equal_nan)
return is_close(a, b)
def in1d(ar1, ar2, invert=False):