!19223 fix linspace accuracy error

Merge pull request !19223 from 杨林枫/fix_linspace_accuracy_error
This commit is contained in:
i-robot 2021-07-02 01:33:52 +00:00 committed by Gitee
commit 5b98330f8d
3 changed files with 13 additions and 24 deletions

View File

@ -36,7 +36,7 @@ from .utils_const import _raise_value_error, _empty, _max, _min, \
_canonicalize_axis, _list_comprehensions, _ceil, _tuple_slice, _raise_unimplemented_error, \
_tuple_setitem
from .array_ops import ravel, concatenate, broadcast_arrays, reshape, broadcast_to, flip, \
apply_along_axis, where
apply_along_axis, where, moveaxis
from .dtypes import nan, pi
# According to official numpy reference, the dimension of a numpy array must be less
@ -518,21 +518,18 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
delta = None
if num > 1:
delta = (stop - start) / div
start_expand = reshape(start, bounds_shape)
# This is similar to how numpy and jax compute linspace
if dtype in (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64):
incremental_expand = reshape(_iota(mstype.float32, num), iota_shape)
delta_expand = reshape(delta, bounds_shape)
start_expand, incremental_expand, delta_expand = broadcast_arrays(
start_expand, incremental_expand, delta_expand)
out = start_expand + (incremental_expand * delta_expand)
else:
stop_expand = reshape(stop, bounds_shape)
step = reshape(_iota(mstype.float32, num), iota_shape) / div
start_expand, stop_expand, step = broadcast_arrays(
start_expand, stop_expand, step)
out = start_expand * (1 - step) + stop_expand * step
start_expand = reshape(start, bounds_shape)
incremental_expand = reshape(_iota(mstype.float32, num), iota_shape)
delta_expand = reshape(delta, bounds_shape)
start_expand, incremental_expand, delta_expand = broadcast_arrays(
start_expand, incremental_expand, delta_expand)
out = start_expand + (incremental_expand * delta_expand)
# recover endpoint
if endpoint:
out = moveaxis(out, axis, 0)
out[-1] = stop
out = moveaxis(out, 0, axis)
elif num == 1:
delta = nan if endpoint else stop - start
out = reshape(start, bounds_shape)

View File

@ -469,7 +469,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
rtol (Number): The relative tolerance parameter (see Note).
atol (Number): The absolute tolerance parameter (see Note).
equal_nan (bool): Whether to compare ``NaN`` as equal. If True, ``NaN`` in
`a` will be considered equal to ``NaN`` in `b` in the output tensor.
`a` will be considered equal to ``NaN`` in `b` in the output tensor.
Returns:
A ``bool`` tensor of where `a` and `b` are equal within the given tolerance.

View File

@ -317,14 +317,6 @@ def test_arange():
expected = mnp.arange(0, 10).asnumpy()
match_array(actual, expected)
actual = onp.arange(start=10)
expected = mnp.arange(start=10).asnumpy()
match_array(actual, expected)
actual = onp.arange(start=10, step=0.1)
expected = mnp.arange(start=10, step=0.1).asnumpy()
match_array(actual, expected, error=6)
actual = onp.arange(10, step=0.1)
expected = mnp.arange(10, step=0.1).asnumpy()
match_array(actual, expected, error=6)