fix
This commit is contained in:
parent
70841afa9d
commit
7f2eb50abb
|
@ -518,13 +518,21 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
|
|||
delta = None
|
||||
if num > 1:
|
||||
delta = (stop - start) / div
|
||||
# This is similar to how numpy and jax compute linspace
|
||||
start_expand = reshape(start, bounds_shape)
|
||||
incremental_expand = reshape(_iota(mstype.float32, num), iota_shape)
|
||||
delta_expand = reshape(delta, bounds_shape)
|
||||
start_expand, incremental_expand, delta_expand = broadcast_arrays(
|
||||
start_expand, incremental_expand, delta_expand)
|
||||
out = start_expand + (incremental_expand * delta_expand)
|
||||
# This is similar to how numpy and jax compute linspace
|
||||
if dtype in (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64):
|
||||
incremental_expand = reshape(_iota(mstype.float32, num), iota_shape)
|
||||
delta_expand = reshape(delta, bounds_shape)
|
||||
start_expand, incremental_expand, delta_expand = broadcast_arrays(
|
||||
start_expand, incremental_expand, delta_expand)
|
||||
out = start_expand + (incremental_expand * delta_expand)
|
||||
else:
|
||||
stop_expand = reshape(stop, bounds_shape)
|
||||
step = reshape(_iota(mstype.float32, num), iota_shape) / div
|
||||
start_expand, stop_expand, step = broadcast_arrays(
|
||||
start_expand, stop_expand, step)
|
||||
out = start_expand * (1 - step) + stop_expand * step
|
||||
elif num == 1:
|
||||
delta = nan if endpoint else stop - start
|
||||
out = reshape(start, bounds_shape)
|
||||
|
@ -2098,6 +2106,8 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None): # pylint: disable
|
|||
start = F.reduce_min(a)
|
||||
end = F.reduce_max(a)
|
||||
else:
|
||||
if not isinstance(range, (list, tuple)) or len(range) != 2:
|
||||
_raise_value_error('`range` should take the form (start, end)')
|
||||
start, end = range
|
||||
if start > end:
|
||||
_raise_value_error('max must be larger than min in range parameter')
|
||||
|
|
|
@ -2095,7 +2095,7 @@ def _get_grid(shape):
|
|||
def choose(a, choices, mode='clip'):
|
||||
"""
|
||||
Construct an array from an index array and a list of arrays to choose from.
|
||||
Given an “index” array `a`` of integers and a sequence of n arrays (choices),
|
||||
Given an “index” array `a` of integers and a sequence of n arrays (choices),
|
||||
`a` and each choice array are first broadcast, as necessary, to arrays of a
|
||||
common shape; calling these `Ba` and `Bchoices[i], i = 0,…,n-1` we have that,
|
||||
necessarily, ``Ba.shape == Bchoices[i].shape`` for each `i`. Then, a new array
|
||||
|
|
|
@ -41,7 +41,7 @@ from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
|
|||
_is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range, \
|
||||
_check_dtype, _list_comprehensions, _tuple_setitem, _add_unit_axes, _seq_prod, \
|
||||
_make_tensor, _promote_for_trigonometric, _raise_runtime_error, _max, _type_convert, \
|
||||
_raise_unimplemented_error, _abs, _in, _tuple_slice
|
||||
_raise_unimplemented_error, _abs, _in, _tuple_slice, _check_is_inf
|
||||
from .utils import _expand, _broadcast_to, _broadcast_to_shape, _check_input_tensor, \
|
||||
_to_tensor, _to_tensor_origin_dtype, _isnan
|
||||
|
||||
|
@ -2156,6 +2156,8 @@ def lcm(x1, x2, dtype=None):
|
|||
q1 = F.tensor_div(x1, common_divisor)
|
||||
q2 = F.tensor_div(x2, common_divisor)
|
||||
res = F.tensor_mul(F.tensor_mul(q1, q2), common_divisor)
|
||||
has_zero = F.equal(multiply(x1, x2), ZERO_TENSOR)
|
||||
res = where_(has_zero, ZERO_TENSOR, res)
|
||||
return F.absolute(res).astype(dtype)
|
||||
|
||||
return _apply_tensor_op(_lcm, x1, x2, dtype=dtype)
|
||||
|
@ -5501,13 +5503,13 @@ def _matrix_norm(x, ord, axis, keepdims): # pylint: disable=redefined-builtin
|
|||
else:
|
||||
axis0, axis1 = axis
|
||||
if not keepdims:
|
||||
if _abs(ord) == inf and axis0 > axis1:
|
||||
if _check_is_inf(_abs(ord)) and axis0 > axis1:
|
||||
axis0 -= 1
|
||||
elif _abs(ord) == 1 and axis1 > axis0:
|
||||
axis1 -= 1
|
||||
if ord == inf:
|
||||
if _check_is_inf(ord):
|
||||
res = P.ReduceMax(keepdims)(P.ReduceSum(keepdims)(absolute(x), axis1), axis0)
|
||||
elif ord == -inf:
|
||||
elif _check_is_inf(ord, True):
|
||||
res = P.ReduceMin(keepdims)(P.ReduceSum(keepdims)(absolute(x), axis1), axis0)
|
||||
elif ord == 1:
|
||||
res = P.ReduceMax(keepdims)(P.ReduceSum(keepdims)(absolute(x), axis0), axis1)
|
||||
|
|
|
@ -509,3 +509,10 @@ def _in(x, y):
|
|||
def _callable_const(x):
|
||||
"""Returns true if x is a function in graph mode."""
|
||||
return isinstance(x, typing.Function)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_inf(x, negative=False):
|
||||
if not negative:
|
||||
return x == float('inf')
|
||||
return x == float('-inf')
|
||||
|
|
Loading…
Reference in New Issue