forked from mindspore-Ecosystem/mindspore
additional typecheck
This commit is contained in:
parent
6f834db9b2
commit
f34756c3bd
|
@ -1200,6 +1200,8 @@ def moveaxis(a, source, destination):
|
||||||
ndim = F.rank(a)
|
ndim = F.rank(a)
|
||||||
source = _check_axis_valid(source, ndim)
|
source = _check_axis_valid(source, ndim)
|
||||||
destination = _check_axis_valid(destination, ndim)
|
destination = _check_axis_valid(destination, ndim)
|
||||||
|
if len(source) != len(destination):
|
||||||
|
_raise_value_error('`source` and `destination` arguments must have the same number of elements')
|
||||||
perm = _get_moved_perm(ndim, source, destination)
|
perm = _get_moved_perm(ndim, source, destination)
|
||||||
|
|
||||||
shape = F.shape(a)
|
shape = F.shape(a)
|
||||||
|
@ -1305,7 +1307,7 @@ def broadcast_to(array, shape):
|
||||||
"""
|
"""
|
||||||
shape_a = F.shape(array)
|
shape_a = F.shape(array)
|
||||||
if not _check_can_broadcast_to(shape_a, shape):
|
if not _check_can_broadcast_to(shape_a, shape):
|
||||||
return _raise_value_error('cannot broadcaast with {shape_a} {shape}')
|
return _raise_value_error('cannot broadcast with ', shape)
|
||||||
return _broadcast_to_shape(array, shape)
|
return _broadcast_to_shape(array, shape)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1386,6 +1388,7 @@ def split(x, indices_or_sections, axis=0):
|
||||||
Tensor(shape=[3], dtype=Float32,
|
Tensor(shape=[3], dtype=Float32,
|
||||||
value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
|
value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
|
||||||
"""
|
"""
|
||||||
|
_check_input_tensor(x)
|
||||||
_ = _check_axis_type(axis, True, False, False)
|
_ = _check_axis_type(axis, True, False, False)
|
||||||
axis = _canonicalize_axis(axis, x.ndim)
|
axis = _canonicalize_axis(axis, x.ndim)
|
||||||
res = None
|
res = None
|
||||||
|
@ -1827,6 +1830,8 @@ def take(a, indices, axis=None, mode='raise'):
|
||||||
[5 7]]
|
[5 7]]
|
||||||
"""
|
"""
|
||||||
_check_input_tensor(a, indices)
|
_check_input_tensor(a, indices)
|
||||||
|
if mode not in ('raise', 'wrap', 'clip'):
|
||||||
|
_raise_value_error('raise should be one of "raise", "wrap", or "clip"')
|
||||||
if axis is None:
|
if axis is None:
|
||||||
a = ravel(a)
|
a = ravel(a)
|
||||||
axis = 0
|
axis = 0
|
||||||
|
|
|
@ -1149,11 +1149,13 @@ def ptp(x, axis=None, out=None, keepdims=False):
|
||||||
[2. 0. 5. 2.]
|
[2. 0. 5. 2.]
|
||||||
"""
|
"""
|
||||||
_check_input_tensor(x)
|
_check_input_tensor(x)
|
||||||
|
if not isinstance(keepdims, bool):
|
||||||
|
_raise_type_error('keepdims should be boolean')
|
||||||
if axis is None:
|
if axis is None:
|
||||||
axis = ()
|
axis = ()
|
||||||
else:
|
else:
|
||||||
_check_axis_type(axis, True, True, False)
|
_check_axis_type(axis, True, True, False)
|
||||||
axis = _canonicalize_axis(axis, x.ndim)
|
axis = _check_axis_valid(axis, x.ndim)
|
||||||
|
|
||||||
if keepdims:
|
if keepdims:
|
||||||
x_min = _reduce_min_keepdims(x, axis)
|
x_min = _reduce_min_keepdims(x, axis)
|
||||||
|
|
|
@ -165,19 +165,18 @@ def _check_axis_valid(axes, ndim):
|
||||||
Checks axes are valid given ndim, and returns axes that can be passed
|
Checks axes are valid given ndim, and returns axes that can be passed
|
||||||
to the built-in operator (non-negative, int or tuple)
|
to the built-in operator (non-negative, int or tuple)
|
||||||
"""
|
"""
|
||||||
if isinstance(axes, int):
|
if axes is None:
|
||||||
_check_axis_in_range(axes, ndim)
|
axes = F.make_range(ndim)
|
||||||
return (axes % ndim,)
|
return axes
|
||||||
if isinstance(axes, (tuple, list)):
|
if isinstance(axes, (tuple, list)):
|
||||||
for axis in axes:
|
for axis in axes:
|
||||||
_check_axis_in_range(axis, ndim)
|
_check_axis_in_range(axis, ndim)
|
||||||
axes = tuple(map(lambda x: x % ndim, axes))
|
axes = tuple(map(lambda x: x % ndim, axes))
|
||||||
if all(axes.count(el) <= 1 for el in axes):
|
if any(axes.count(el) > 1 for el in axes):
|
||||||
return axes
|
|
||||||
if axes is None:
|
|
||||||
axes = F.make_range(ndim)
|
|
||||||
return axes
|
|
||||||
raise ValueError('duplicate value in "axis"')
|
raise ValueError('duplicate value in "axis"')
|
||||||
|
return axes
|
||||||
|
_check_axis_in_range(axes, ndim)
|
||||||
|
return (axes % ndim,)
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
|
Loading…
Reference in New Issue