additional typecheck

This commit is contained in:
huangmengxi 2021-03-02 19:26:40 +08:00
parent 6f834db9b2
commit f34756c3bd
3 changed files with 16 additions and 10 deletions

View File

@ -1200,6 +1200,8 @@ def moveaxis(a, source, destination):
ndim = F.rank(a) ndim = F.rank(a)
source = _check_axis_valid(source, ndim) source = _check_axis_valid(source, ndim)
destination = _check_axis_valid(destination, ndim) destination = _check_axis_valid(destination, ndim)
if len(source) != len(destination):
_raise_value_error('`source` and `destination` arguments must have the same number of elements')
perm = _get_moved_perm(ndim, source, destination) perm = _get_moved_perm(ndim, source, destination)
shape = F.shape(a) shape = F.shape(a)
@ -1305,7 +1307,7 @@ def broadcast_to(array, shape):
""" """
shape_a = F.shape(array) shape_a = F.shape(array)
if not _check_can_broadcast_to(shape_a, shape): if not _check_can_broadcast_to(shape_a, shape):
return _raise_value_error('cannot broadcaast with {shape_a} {shape}') return _raise_value_error('cannot broadcast with ', shape)
return _broadcast_to_shape(array, shape) return _broadcast_to_shape(array, shape)
@ -1386,6 +1388,7 @@ def split(x, indices_or_sections, axis=0):
Tensor(shape=[3], dtype=Float32, Tensor(shape=[3], dtype=Float32,
value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00])) value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
""" """
_check_input_tensor(x)
_ = _check_axis_type(axis, True, False, False) _ = _check_axis_type(axis, True, False, False)
axis = _canonicalize_axis(axis, x.ndim) axis = _canonicalize_axis(axis, x.ndim)
res = None res = None
@ -1827,6 +1830,8 @@ def take(a, indices, axis=None, mode='raise'):
[5 7]] [5 7]]
""" """
_check_input_tensor(a, indices) _check_input_tensor(a, indices)
if mode not in ('raise', 'wrap', 'clip'):
_raise_value_error('raise should be one of "raise", "wrap", or "clip"')
if axis is None: if axis is None:
a = ravel(a) a = ravel(a)
axis = 0 axis = 0

View File

@ -1149,11 +1149,13 @@ def ptp(x, axis=None, out=None, keepdims=False):
[2. 0. 5. 2.] [2. 0. 5. 2.]
""" """
_check_input_tensor(x) _check_input_tensor(x)
if not isinstance(keepdims, bool):
_raise_type_error('keepdims should be boolean')
if axis is None: if axis is None:
axis = () axis = ()
else: else:
_check_axis_type(axis, True, True, False) _check_axis_type(axis, True, True, False)
axis = _canonicalize_axis(axis, x.ndim) axis = _check_axis_valid(axis, x.ndim)
if keepdims: if keepdims:
x_min = _reduce_min_keepdims(x, axis) x_min = _reduce_min_keepdims(x, axis)

View File

@ -165,19 +165,18 @@ def _check_axis_valid(axes, ndim):
Checks axes are valid given ndim, and returns axes that can be passed Checks axes are valid given ndim, and returns axes that can be passed
to the built-in operator (non-negative, int or tuple) to the built-in operator (non-negative, int or tuple)
""" """
if isinstance(axes, int): if axes is None:
_check_axis_in_range(axes, ndim) axes = F.make_range(ndim)
return (axes % ndim,) return axes
if isinstance(axes, (tuple, list)): if isinstance(axes, (tuple, list)):
for axis in axes: for axis in axes:
_check_axis_in_range(axis, ndim) _check_axis_in_range(axis, ndim)
axes = tuple(map(lambda x: x % ndim, axes)) axes = tuple(map(lambda x: x % ndim, axes))
if all(axes.count(el) <= 1 for el in axes): if any(axes.count(el) > 1 for el in axes):
return axes
if axes is None:
axes = F.make_range(ndim)
return axes
raise ValueError('duplicate value in "axis"') raise ValueError('duplicate value in "axis"')
return axes
_check_axis_in_range(axes, ndim)
return (axes % ndim,)
@constexpr @constexpr