forked from mindspore-Ecosystem/mindspore
!13310 numpy-native fix linspace, diag, maximum, minimum
From: @jachua
Reviewed-by: @guoqi1024, @liangchenghui
Signed-off-by: @liangchenghui
This commit is contained in:
commit 7c4d7a89e1

@@ -22,15 +22,14 @@ from ..ops.primitive import constexpr
 from ..nn.layer.basic import tril as nn_tril
 from ..nn.layer.basic import triu as nn_triu
 from .._c_expression import Tensor as Tensor_
 from .._c_expression.typing import Float
 
 from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \
     _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \
     _expand
 from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \
     _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \
-    _raise_type_error, _expanded_shape, _tuple_getitem, _check_is_float, _iota, \
-    _type_convert, _canonicalize_axis, _list_comprehensions, _ceil
+    _raise_type_error, _expanded_shape, _check_is_float, _iota, _type_convert, \
+    _canonicalize_axis, _list_comprehensions, _ceil, _tuple_getitem, _tuple_slice
 from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to
 from .dtypes import nan

@@ -503,9 +502,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
     start, stop = broadcast_arrays(start, stop)
     axis = _canonicalize_axis(axis, start.ndim+1)
     bounds_shape = start.shape
-    bounds_shape = bounds_shape[:axis] + (1,) + bounds_shape[axis:]
+    bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
     iota_shape = _list_comprehensions(start.ndim+1, 1, True)
-    iota_shape = iota_shape[:axis] + (num,) + iota_shape[axis+1:]
+    iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
     num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
     div = (num_tensor - 1) if endpoint else num_tensor

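The two `_tuple_slice` calls compute the same tuples as the slice syntax they replace; the helper is presumably needed so that graph mode can fold the slice as a compile-time constant. A minimal plain-NumPy sketch of the shape arithmetic, with `axis` and `num` standing in for the `linspace` arguments:

    import numpy as np

    start = np.zeros((2, 3))          # start/stop after broadcast_arrays
    axis, num = 1, 5
    # bounds get a length-1 axis inserted; the step index (iota) gets length num there
    bounds_shape = start.shape[:axis] + (1,) + start.shape[axis:]   # (2, 1, 3)
    iota_shape = (1,) * axis + (num,) + (1,) * (start.ndim - axis)  # (1, 5, 1)
    print(np.broadcast_shapes(bounds_shape, iota_shape))            # (2, 5, 3)
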
@@ -1542,7 +1541,7 @@ def diag(v, k=0):
         prod = F.tensor_mul(v, e)
 
         cast_type = dtype
-        if not isinstance(dtype, Float):
+        if not _check_is_float(dtype):
             # reduce sum only supports float types
             cast_type = mstype.float32
             prod = F.cast(prod, cast_type)

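`_check_is_float` replaces the `isinstance(dtype, Float)` check, presumably because the helper works across the dtype representations seen under constexpr evaluation. For context, the cast exists because the reduce-sum backend op is float-only; a plain-NumPy sketch of the mask-multiply-reduce pattern used here:

    import numpy as np

    v = np.arange(9, dtype=np.int32).reshape(3, 3)
    e = np.eye(3, dtype=np.int32)       # mask keeping only the diagonal
    prod = v * e
    # cast before the reduce since the backend sum supports only float types
    out = prod.astype(np.float32).sum(axis=1).astype(v.dtype)
    print(out)                          # [0 4 8]
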
@@ -1809,10 +1809,7 @@ def _check_indices(size, indices, mode):
     out_of_lowerbounds = F.tensor_lt(indices, lowerbounds)
     out_of_upperbounds = F.tensor_gt(indices, upperbounds)
     if mode == 'raise':
-        # For mode raise, index-out-of-bounds checking is performed at backend since
-        # evaluation of a boolean scalar Tensor always returns true in graph mode
-        # regardless of the truth value contained
-        return indices
+        _raise_unimplemented_error('"raise" mode is not implemented')
     if mode == 'wrap':
         return _mod(indices, F.fill(dtype, shape, size))
     zeros = F.fill(dtype, shape, 0)

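The 'raise' branch now fails loudly at compile time instead of silently returning the indices; the removed comments explain why a runtime check cannot work (a boolean scalar Tensor always evaluates as true in graph mode). The 'wrap' and 'clip' behaviours mirror NumPy's, which serves as a quick reference:

    import numpy as np

    indices, size = np.array([-7, -1, 3, 6]), 5
    print(np.mod(indices, size))          # 'wrap' folds out-of-range: [3 4 3 1]
    print(np.clip(indices, 0, size - 1))  # 'clip' pins to the ends:  [0 0 3 4]
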
@@ -1821,7 +1818,7 @@ def _check_indices(size, indices, mode):
     return clipped
 
 
-def take(a, indices, axis=None, mode='raise'):
+def take(a, indices, axis=None, mode='clip'):
     """
     Takes elements from an array along an axis.

@@ -1832,6 +1829,7 @@ def take(a, indices, axis=None, mode='raise'):
 
     Note:
         Numpy argument out is not supported.
+        ``mode = 'raise'`` is not supported, and the default mode is 'clip' instead.
 
     Args:
         a (Tensor): Source array with shape `(Ni…, M, Nk…)`.

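Since 'clip' is now the default, out-of-range indices are pinned to the valid range rather than raising. NumPy's own `take` offers the same mode, so the expected behaviour can be checked against it:

    import numpy as np

    a = np.array([4, 3, 5, 7, 6, 8])
    print(np.take(a, [0, 1, 4]))               # [4 3 6]
    print(np.take(a, [10, -20], mode='clip'))  # pinned to last/first: [8 4]
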
@@ -580,12 +580,12 @@ def minimum(x1, x2, dtype=None):
     # comparisons with 2 scalars
     if x1.ndim == 0 and x2.ndim == 0:
         x1 = expand_dims(x1, 0)
-        return _apply_tensor_op(F.minimum, x1, x2, dtype=dtype).squeeze()
+        return _apply_tensor_op(functools.partial(_prop_nan, F.minimum), x1, x2, dtype=dtype).squeeze()
     if x1.ndim == 0:
         dtype = x2.dtype
     elif x2.ndim == 0:
         dtype = x1.dtype
-    return _apply_tensor_op(F.minimum, x1, x2, dtype=dtype)
+    return _apply_tensor_op(functools.partial(_prop_nan, F.minimum), x1, x2, dtype=dtype)
 
 
 def mean(a, axis=None, keepdims=False, dtype=None):

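Routing both return paths through `_prop_nan` makes `minimum` follow NumPy's rule that NaN propagates through comparisons, which bare `F.minimum` does not guarantee on every backend (my reading of the fix). The target semantics in NumPy:

    import numpy as np

    print(np.minimum(np.nan, 1.0))              # nan
    print(np.minimum([1.0, np.nan, 3.0], 2.0))  # [ 1. nan  2.]
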
@@ -1299,6 +1299,14 @@ def log(x, dtype=None):
     return _apply_tensor_op(F.log, x, dtype=dtype)
 
 
+def _prop_nan(fn, x1, x2):
+    """Selects NaN if either element is NaN"""
+    has_nan = F.logical_or(_isnan(x1), _isnan(x2))
+    nan_tensor = F.fill(_promote(F.dtype(x1), F.dtype(x2)), F.shape(has_nan), nan)
+    res = fn(x1, x2)
+    return F.select(has_nan, nan_tensor, res)
+
+
 def maximum(x1, x2, dtype=None):
     """
     Returns the element-wise maximum of array elements.

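A self-contained sketch of how `_prop_nan` composes with `functools.partial`, using NumPy stand-ins for the `F.*` primitives; `np.fmin` deliberately drops NaN, so the wrapper's effect is visible:

    import functools
    import numpy as np

    def prop_nan(fn, x1, x2):
        """Select NaN wherever either input is NaN, else fn's result."""
        has_nan = np.logical_or(np.isnan(x1), np.isnan(x2))
        return np.where(has_nan, np.nan, fn(x1, x2))

    nan_min = functools.partial(prop_nan, np.fmin)
    x, y = np.array([1.0, np.nan]), np.array([np.nan, 2.0])
    print(np.fmin(x, y))   # [1. 2.]   NaN silently dropped
    print(nan_min(x, y))   # [nan nan] NaN propagated, as np.minimum does
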
@@ -1349,12 +1357,12 @@ def maximum(x1, x2, dtype=None):
     # F.maximum does not support when both operands are scalar
     if x1.ndim == 0 and x2.ndim == 0:
         x1 = expand_dims(x1, 0)
-        return _apply_tensor_op(F.maximum, x1, x2, dtype=dtype).squeeze()
+        return _apply_tensor_op(functools.partial(_prop_nan, F.maximum), x1, x2, dtype=dtype).squeeze()
     if x1.ndim == 0:
         dtype = x2.dtype
     elif x2.ndim == 0:
         dtype = x1.dtype
-    return _apply_tensor_op(F.maximum, x1, x2, dtype=dtype)
+    return _apply_tensor_op(functools.partial(_prop_nan, F.maximum), x1, x2, dtype=dtype)
 
 
 def heaviside(x1, x2, dtype=None):

@@ -1567,7 +1575,7 @@ def hypot(x1, x2, dtype=None):
         [[5. 5. 5.]
          [5. 5. 5.]
          [5. 5. 5.]]
-        >>> output = np.hypot(3*np.ones((3, 3)), np.array([4]))
+        >>> output = np.hypot(3*np.ones((3, 3)), np.array([4.0]))
         >>> print(output)
         [[5. 5. 5.]
          [5. 5. 5.]

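The doctest operand changes from `np.array([4])` to `np.array([4.0])`; presumably the integer tensor tripped the float-only sqrt path, so the example now feeds a float tensor. The printed 3-4-5 result is unchanged:

    import numpy as np

    print(np.hypot(3 * np.ones((3, 3)), np.array([4.0])))  # 3x3 matrix of 5.0
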
@@ -219,6 +219,22 @@ def _raise_runtime_error(info, param=None):
         raise RuntimeError(info)
     raise RuntimeError(info + f"{param}")
 
 
+@constexpr
+def _raise_unimplemented_error(info, param=None):
+    """
+    Raise NotImplementedError in both graph/pynative mode
+
+    Args:
+        info(str): info string to display
+        param(python obj): any object that can be recognized by graph mode. If is
+            not None, then param's value information will be extracted and displayed.
+            Default is None.
+    """
+    if param is None:
+        raise NotImplementedError(info)
+    raise NotImplementedError(info + f"{param}")
+
+
 @constexpr
 def _empty(dtype, shape):
     """Returns an uninitialized array with dtype and shape."""

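The new helper mirrors `_raise_runtime_error` directly above: with `param=None` only the message is raised, otherwise the parameter's value is appended. A standalone re-run of that logic:

    def _raise_unimplemented_error(info, param=None):
        if param is None:
            raise NotImplementedError(info)
        raise NotImplementedError(info + f"{param}")

    try:
        _raise_unimplemented_error("unsupported mode: ", "raise")
    except NotImplementedError as e:
        print(e)   # unsupported mode: raise
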
@@ -454,3 +470,8 @@ def _seq_prod(seq1, seq2):
 def _make_tensor(val, dtype):
     """ Returns the tensor with value `val` and dtype `dtype`."""
     return Tensor(val, dtype)
+
+
+def _tuple_slice(tup, start, end):
+    """get sliced tuple from start and end."""
+    return tup[start:end]

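`_tuple_slice(tup, None, axis)` is `tup[:axis]` behind a helper (Python slices accept `None` bounds), which lets the `linspace` hunk above express the slicing through a function call. Quick check:

    def _tuple_slice(tup, start, end):
        """get sliced tuple from start and end."""
        return tup[start:end]

    shape, axis = (2, 3, 4), 1
    print(_tuple_slice(shape, None, axis) + (1,) + _tuple_slice(shape, axis, None))
    # (2, 1, 3, 4) -- the bounds_shape construction from linspace
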
@@ -251,6 +251,16 @@ def test_float_power():
 @pytest.mark.env_onecard
 def test_minimum():
     run_binop_test(mnp_minimum, onp_minimum, test_case)
+    x = onp.random.randint(-10, 10, 20).astype(onp.float32)
+    y = onp.random.randint(-10, 10, 20).astype(onp.float32)
+    x[onp.random.randint(0, 10, 3)] = onp.nan
+    y[onp.random.randint(0, 10, 3)] = onp.nan
+    x[onp.random.randint(0, 10, 3)] = onp.NINF
+    y[onp.random.randint(0, 10, 3)] = onp.NINF
+    x[onp.random.randint(0, 10, 3)] = onp.PINF
+    y[onp.random.randint(0, 10, 3)] = onp.PINF
+    match_res(mnp_minimum, onp_minimum, x, y)
+    match_res(mnp_minimum, onp_minimum, y, x)
 
 
 def mnp_tensordot(x, y):

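The test seeds NaN and both infinities at random positions (indices may collide, which is harmless) and compares MindSpore against NumPy in both argument orders. The invariant it exercises, in NumPy alone:

    import numpy as onp

    x = onp.random.randint(-10, 10, 20).astype(onp.float32)
    y = onp.random.randint(-10, 10, 20).astype(onp.float32)
    x[onp.random.randint(0, 10, 3)] = onp.nan     # onp.NINF/onp.PINF are -inf/+inf
    y[onp.random.randint(0, 10, 3)] = -onp.inf
    expected = onp.minimum(x, y)
    mask = onp.isnan(x) | onp.isnan(y)
    assert onp.isnan(expected[mask]).all()        # NaN wins wherever present
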
@@ -924,6 +934,16 @@ def onp_maximum(x1, x2):
 @pytest.mark.env_onecard
 def test_maximum():
     run_binop_test(mnp_maximum, onp_maximum, test_case)
+    x = onp.random.randint(-10, 10, 20).astype(onp.float32)
+    y = onp.random.randint(-10, 10, 20).astype(onp.float32)
+    x[onp.random.randint(0, 10, 3)] = onp.nan
+    y[onp.random.randint(0, 10, 3)] = onp.nan
+    x[onp.random.randint(0, 10, 3)] = onp.NINF
+    y[onp.random.randint(0, 10, 3)] = onp.NINF
+    x[onp.random.randint(0, 10, 3)] = onp.PINF
+    y[onp.random.randint(0, 10, 3)] = onp.PINF
+    match_res(mnp_maximum, onp_maximum, x, y)
+    match_res(mnp_maximum, onp_maximum, y, x)
 
 
 def mnp_clip(x):