!18985 numpy-native fix norm, linspace

Merge pull request !18985 from huangmengxi/numpy_fix
This commit is contained in:
i-robot 2021-06-30 02:49:09 +00:00 committed by Gitee
commit 2fb9921b6b
4 changed files with 30 additions and 11 deletions

View File

@ -518,13 +518,21 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
delta = None
if num > 1:
delta = (stop - start) / div
# This is similar to how numpy and jax compute linspace
start_expand = reshape(start, bounds_shape)
incremental_expand = reshape(_iota(mstype.float32, num), iota_shape)
delta_expand = reshape(delta, bounds_shape)
start_expand, incremental_expand, delta_expand = broadcast_arrays(
start_expand, incremental_expand, delta_expand)
out = start_expand + (incremental_expand * delta_expand)
# This is similar to how numpy and jax compute linspace
if dtype in (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64):
incremental_expand = reshape(_iota(mstype.float32, num), iota_shape)
delta_expand = reshape(delta, bounds_shape)
start_expand, incremental_expand, delta_expand = broadcast_arrays(
start_expand, incremental_expand, delta_expand)
out = start_expand + (incremental_expand * delta_expand)
else:
stop_expand = reshape(stop, bounds_shape)
step = reshape(_iota(mstype.float32, num), iota_shape) / div
start_expand, stop_expand, step = broadcast_arrays(
start_expand, stop_expand, step)
out = start_expand * (1 - step) + stop_expand * step
elif num == 1:
delta = nan if endpoint else stop - start
out = reshape(start, bounds_shape)
@ -2098,6 +2106,8 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None): # pylint: disable
start = F.reduce_min(a)
end = F.reduce_max(a)
else:
if not isinstance(range, (list, tuple)) or len(range) != 2:
_raise_value_error('`range` should take the form (start, end)')
start, end = range
if start > end:
_raise_value_error('max must be larger than min in range parameter')

View File

@ -2095,7 +2095,7 @@ def _get_grid(shape):
def choose(a, choices, mode='clip'):
"""
Construct an array from an index array and a list of arrays to choose from.
Given an index array `a`` of integers and a sequence of n arrays (choices),
Given an index array `a` of integers and a sequence of n arrays (choices),
`a` and each choice array are first broadcast, as necessary, to arrays of a
common shape; calling these `Ba` and `Bchoices[i], i = 0,,n-1` we have that,
necessarily, ``Ba.shape == Bchoices[i].shape`` for each `i`. Then, a new array

View File

@ -41,7 +41,7 @@ from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
_is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range, \
_check_dtype, _list_comprehensions, _tuple_setitem, _add_unit_axes, _seq_prod, \
_make_tensor, _promote_for_trigonometric, _raise_runtime_error, _max, _type_convert, \
_raise_unimplemented_error, _abs, _in, _tuple_slice
_raise_unimplemented_error, _abs, _in, _tuple_slice, _check_is_inf
from .utils import _expand, _broadcast_to, _broadcast_to_shape, _check_input_tensor, \
_to_tensor, _to_tensor_origin_dtype, _isnan
@ -2163,6 +2163,8 @@ def lcm(x1, x2, dtype=None):
q1 = F.tensor_div(x1, common_divisor)
q2 = F.tensor_div(x2, common_divisor)
res = F.tensor_mul(F.tensor_mul(q1, q2), common_divisor)
has_zero = F.equal(multiply(x1, x2), ZERO_TENSOR)
res = where_(has_zero, ZERO_TENSOR, res)
return F.absolute(res).astype(dtype)
return _apply_tensor_op(_lcm, x1, x2, dtype=dtype)
@ -5508,13 +5510,13 @@ def _matrix_norm(x, ord, axis, keepdims): # pylint: disable=redefined-builtin
else:
axis0, axis1 = axis
if not keepdims:
if _abs(ord) == inf and axis0 > axis1:
if _check_is_inf(_abs(ord)) and axis0 > axis1:
axis0 -= 1
elif _abs(ord) == 1 and axis1 > axis0:
axis1 -= 1
if ord == inf:
if _check_is_inf(ord):
res = P.ReduceMax(keepdims)(P.ReduceSum(keepdims)(absolute(x), axis1), axis0)
elif ord == -inf:
elif _check_is_inf(ord, True):
res = P.ReduceMin(keepdims)(P.ReduceSum(keepdims)(absolute(x), axis1), axis0)
elif ord == 1:
res = P.ReduceMax(keepdims)(P.ReduceSum(keepdims)(absolute(x), axis0), axis1)

View File

@ -509,3 +509,10 @@ def _in(x, y):
def _callable_const(x):
"""Returns true if x is a function in graph mode."""
return isinstance(x, typing.Function)
@constexpr
def _check_is_inf(x, negative=False):
if not negative:
return x == float('inf')
return x == float('-inf')