forked from mindspore-Ecosystem/mindspore
!19223 fix linspace accuracy error
Merge pull request !19223 from 杨林枫/fix_linspace_accuracy_error
This commit is contained in:
commit
5b98330f8d
|
@ -36,7 +36,7 @@ from .utils_const import _raise_value_error, _empty, _max, _min, \
|
|||
_canonicalize_axis, _list_comprehensions, _ceil, _tuple_slice, _raise_unimplemented_error, \
|
||||
_tuple_setitem
|
||||
from .array_ops import ravel, concatenate, broadcast_arrays, reshape, broadcast_to, flip, \
|
||||
apply_along_axis, where
|
||||
apply_along_axis, where, moveaxis
|
||||
from .dtypes import nan, pi
|
||||
|
||||
# According to official numpy reference, the dimension of a numpy array must be less
|
||||
|
@ -518,21 +518,18 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
|
|||
delta = None
|
||||
if num > 1:
|
||||
delta = (stop - start) / div
|
||||
start_expand = reshape(start, bounds_shape)
|
||||
# This is similar to how numpy and jax compute linspace
|
||||
if dtype in (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64):
|
||||
incremental_expand = reshape(_iota(mstype.float32, num), iota_shape)
|
||||
delta_expand = reshape(delta, bounds_shape)
|
||||
start_expand, incremental_expand, delta_expand = broadcast_arrays(
|
||||
start_expand, incremental_expand, delta_expand)
|
||||
out = start_expand + (incremental_expand * delta_expand)
|
||||
else:
|
||||
stop_expand = reshape(stop, bounds_shape)
|
||||
step = reshape(_iota(mstype.float32, num), iota_shape) / div
|
||||
start_expand, stop_expand, step = broadcast_arrays(
|
||||
start_expand, stop_expand, step)
|
||||
out = start_expand * (1 - step) + stop_expand * step
|
||||
start_expand = reshape(start, bounds_shape)
|
||||
incremental_expand = reshape(_iota(mstype.float32, num), iota_shape)
|
||||
delta_expand = reshape(delta, bounds_shape)
|
||||
start_expand, incremental_expand, delta_expand = broadcast_arrays(
|
||||
start_expand, incremental_expand, delta_expand)
|
||||
out = start_expand + (incremental_expand * delta_expand)
|
||||
# recover endpoint
|
||||
if endpoint:
|
||||
out = moveaxis(out, axis, 0)
|
||||
out[-1] = stop
|
||||
out = moveaxis(out, 0, axis)
|
||||
elif num == 1:
|
||||
delta = nan if endpoint else stop - start
|
||||
out = reshape(start, bounds_shape)
|
||||
|
|
|
@ -469,7 +469,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
|
|||
rtol (Number): The relative tolerance parameter (see Note).
|
||||
atol (Number): The absolute tolerance parameter (see Note).
|
||||
equal_nan (bool): Whether to compare ``NaN`` as equal. If True, ``NaN`` in
|
||||
`a` will be considered equal to ``NaN`` in `b` in the output tensor.
|
||||
`a` will be considered equal to ``NaN`` in `b` in the output tensor.
|
||||
|
||||
Returns:
|
||||
A ``bool`` tensor of where `a` and `b` are equal within the given tolerance.
|
||||
|
|
|
@ -317,14 +317,6 @@ def test_arange():
|
|||
expected = mnp.arange(0, 10).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(start=10)
|
||||
expected = mnp.arange(start=10).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(start=10, step=0.1)
|
||||
expected = mnp.arange(start=10, step=0.1).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
||||
actual = onp.arange(10, step=0.1)
|
||||
expected = mnp.arange(10, step=0.1).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
|
Loading…
Reference in New Issue