forked from mindspore-Ecosystem/mindspore
additional typecheck
This commit is contained in:
parent
6f834db9b2
commit
f34756c3bd
|
@ -1200,6 +1200,8 @@ def moveaxis(a, source, destination):
|
|||
ndim = F.rank(a)
|
||||
source = _check_axis_valid(source, ndim)
|
||||
destination = _check_axis_valid(destination, ndim)
|
||||
if len(source) != len(destination):
|
||||
_raise_value_error('`source` and `destination` arguments must have the same number of elements')
|
||||
perm = _get_moved_perm(ndim, source, destination)
|
||||
|
||||
shape = F.shape(a)
|
||||
|
@ -1305,7 +1307,7 @@ def broadcast_to(array, shape):
|
|||
"""
|
||||
shape_a = F.shape(array)
|
||||
if not _check_can_broadcast_to(shape_a, shape):
|
||||
return _raise_value_error('cannot broadcaast with {shape_a} {shape}')
|
||||
return _raise_value_error('cannot broadcast with ', shape)
|
||||
return _broadcast_to_shape(array, shape)
|
||||
|
||||
|
||||
|
@ -1386,6 +1388,7 @@ def split(x, indices_or_sections, axis=0):
|
|||
Tensor(shape=[3], dtype=Float32,
|
||||
value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
|
||||
"""
|
||||
_check_input_tensor(x)
|
||||
_ = _check_axis_type(axis, True, False, False)
|
||||
axis = _canonicalize_axis(axis, x.ndim)
|
||||
res = None
|
||||
|
@ -1827,6 +1830,8 @@ def take(a, indices, axis=None, mode='raise'):
|
|||
[5 7]]
|
||||
"""
|
||||
_check_input_tensor(a, indices)
|
||||
if mode not in ('raise', 'wrap', 'clip'):
|
||||
_raise_value_error('raise should be one of "raise", "wrap", or "clip"')
|
||||
if axis is None:
|
||||
a = ravel(a)
|
||||
axis = 0
|
||||
|
|
|
@ -1149,11 +1149,13 @@ def ptp(x, axis=None, out=None, keepdims=False):
|
|||
[2. 0. 5. 2.]
|
||||
"""
|
||||
_check_input_tensor(x)
|
||||
if not isinstance(keepdims, bool):
|
||||
_raise_type_error('keepdims should be boolean')
|
||||
if axis is None:
|
||||
axis = ()
|
||||
else:
|
||||
_check_axis_type(axis, True, True, False)
|
||||
axis = _canonicalize_axis(axis, x.ndim)
|
||||
axis = _check_axis_valid(axis, x.ndim)
|
||||
|
||||
if keepdims:
|
||||
x_min = _reduce_min_keepdims(x, axis)
|
||||
|
|
|
@ -165,19 +165,18 @@ def _check_axis_valid(axes, ndim):
|
|||
Checks axes are valid given ndim, and returns axes that can be passed
|
||||
to the built-in operator (non-negative, int or tuple)
|
||||
"""
|
||||
if isinstance(axes, int):
|
||||
_check_axis_in_range(axes, ndim)
|
||||
return (axes % ndim,)
|
||||
if axes is None:
|
||||
axes = F.make_range(ndim)
|
||||
return axes
|
||||
if isinstance(axes, (tuple, list)):
|
||||
for axis in axes:
|
||||
_check_axis_in_range(axis, ndim)
|
||||
axes = tuple(map(lambda x: x % ndim, axes))
|
||||
if all(axes.count(el) <= 1 for el in axes):
|
||||
return axes
|
||||
if axes is None:
|
||||
axes = F.make_range(ndim)
|
||||
return axes
|
||||
if any(axes.count(el) > 1 for el in axes):
|
||||
raise ValueError('duplicate value in "axis"')
|
||||
return axes
|
||||
_check_axis_in_range(axes, ndim)
|
||||
return (axes % ndim,)
|
||||
|
||||
|
||||
@constexpr
|
||||
|
|
Loading…
Reference in New Issue