forked from mindspore-Ecosystem/mindspore
!16535 fix code_check warnings
From: @yanglf1121 Reviewed-by: @guoqi1024,@kingxian Signed-off-by: @kingxian
This commit is contained in:
commit
6bbbfe32a5
|
@ -219,8 +219,7 @@ def asfarray(a, dtype=mstype.float32):
|
||||||
return asarray(a)
|
return asarray(a)
|
||||||
|
|
||||||
dtype = _check_dtype(dtype)
|
dtype = _check_dtype(dtype)
|
||||||
# pylint: disable=consider-using-in
|
if dtype not in (mstype.float16, mstype.float32, mstype.float64):
|
||||||
if dtype != mstype.float16 and dtype != mstype.float32 and dtype != mstype.float64:
|
|
||||||
dtype = mstype.float32
|
dtype = mstype.float32
|
||||||
|
|
||||||
if isinstance(a, Tensor):
|
if isinstance(a, Tensor):
|
||||||
|
@ -436,7 +435,7 @@ def arange(start, stop=None, step=None, dtype=None):
|
||||||
return out.astype(dtype)
|
return out.astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
def _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis):
|
def _type_checking_for_xspace(start, stop, num, endpoint, dtype):
|
||||||
"""utility parameter checking function for linspace, logspace, geomspace."""
|
"""utility parameter checking function for linspace, logspace, geomspace."""
|
||||||
if not isinstance(start, ARRAY_TYPES):
|
if not isinstance(start, ARRAY_TYPES):
|
||||||
_raise_type_error("start should be int, float, bool, list, tuple, Tensor, but got", start)
|
_raise_type_error("start should be int, float, bool, list, tuple, Tensor, but got", start)
|
||||||
|
@ -455,8 +454,18 @@ def _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis):
|
||||||
else:
|
else:
|
||||||
dtype = mstype.float32
|
dtype = mstype.float32
|
||||||
start, stop = broadcast_arrays(start, stop)
|
start, stop = broadcast_arrays(start, stop)
|
||||||
axis = _canonicalize_axis(axis, start.ndim+1)
|
return start, stop, num, endpoint, dtype
|
||||||
return start, stop, num, endpoint, dtype, axis
|
|
||||||
|
|
||||||
|
def _compute_shapes(start, axis, num, endpoint):
|
||||||
|
"""Computes shapes for local variables for np.linspace"""
|
||||||
|
bounds_shape = start.shape
|
||||||
|
bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
|
||||||
|
iota_shape = _list_comprehensions(start.ndim+1, 1, True)
|
||||||
|
iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
|
||||||
|
num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
|
||||||
|
div = (num_tensor - 1) if endpoint else num_tensor
|
||||||
|
return bounds_shape, iota_shape, div
|
||||||
|
|
||||||
|
|
||||||
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
|
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
|
||||||
|
@ -500,15 +509,11 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
|
||||||
[0. 1. 2. 3. 4. 5.]
|
[0. 1. 2. 3. 4. 5.]
|
||||||
"""
|
"""
|
||||||
# This implementation was inspired by jax.numpy.linspace and numpy.linspace
|
# This implementation was inspired by jax.numpy.linspace and numpy.linspace
|
||||||
start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis)
|
start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
|
||||||
|
axis = _canonicalize_axis(axis, start.ndim+1)
|
||||||
if not isinstance(retstep, bool):
|
if not isinstance(retstep, bool):
|
||||||
_raise_type_error("retstep should be an boolean, but got ", retstep)
|
_raise_type_error("retstep should be an boolean, but got ", retstep)
|
||||||
bounds_shape = start.shape
|
bounds_shape, iota_shape, div = _compute_shapes(start, axis, num, endpoint)
|
||||||
bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
|
|
||||||
iota_shape = _list_comprehensions(start.ndim+1, 1, True)
|
|
||||||
iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
|
|
||||||
num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
|
|
||||||
div = (num_tensor - 1) if endpoint else num_tensor
|
|
||||||
out = None
|
out = None
|
||||||
delta = None
|
delta = None
|
||||||
if num > 1:
|
if num > 1:
|
||||||
|
@ -572,7 +577,8 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
|
||||||
[ 1. 2. 4. 8. 16. 32.]
|
[ 1. 2. 4. 8. 16. 32.]
|
||||||
"""
|
"""
|
||||||
# This implementation was inspired by jax.numpy.linspace and numpy.linspace
|
# This implementation was inspired by jax.numpy.linspace and numpy.linspace
|
||||||
start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis)
|
start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
|
||||||
|
axis = _canonicalize_axis(axis, start.ndim+1)
|
||||||
if not isinstance(base, (int, float, bool)):
|
if not isinstance(base, (int, float, bool)):
|
||||||
_raise_type_error("base should be a number, but got ", base)
|
_raise_type_error("base should be a number, but got ", base)
|
||||||
linspace_res = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis)
|
linspace_res = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis)
|
||||||
|
@ -620,7 +626,8 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
|
||||||
>>> print(output)
|
>>> print(output)
|
||||||
[ 1. 2. 4. 8. 16. 32. 64. 128.]
|
[ 1. 2. 4. 8. 16. 32. 64. 128.]
|
||||||
"""
|
"""
|
||||||
start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis)
|
start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
|
||||||
|
axis = _canonicalize_axis(axis, start.ndim+1)
|
||||||
root = num
|
root = num
|
||||||
if endpoint:
|
if endpoint:
|
||||||
root -= 1
|
root -= 1
|
||||||
|
@ -728,6 +735,7 @@ def identity(n, dtype=mstype.float32):
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def empty_compile(dtype, shape):
|
def empty_compile(dtype, shape):
|
||||||
|
"""Returns an empty Tensor."""
|
||||||
return Tensor_(dtype, shape)
|
return Tensor_(dtype, shape)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1183,9 +1191,9 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None):
|
||||||
return a.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
|
return a.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def _index(i, size, Cartesian=True):
|
def _index(i, size, cartesian=True):
|
||||||
"""If Cartesian=True, index 0 is swapped with index 1."""
|
"""If cartesian=True, index 0 is swapped with index 1."""
|
||||||
if Cartesian:
|
if cartesian:
|
||||||
if i == 1:
|
if i == 1:
|
||||||
return 0
|
return 0
|
||||||
if i == 0 and size >= 2:
|
if i == 0 and size >= 2:
|
||||||
|
@ -1272,12 +1280,12 @@ def meshgrid(*xi, sparse=False, indexing='xy'):
|
||||||
Cartesian = indexing == 'xy'
|
Cartesian = indexing == 'xy'
|
||||||
shape_out = ()
|
shape_out = ()
|
||||||
for i in range(len(grids)):
|
for i in range(len(grids)):
|
||||||
grid_index = _index(i, ndim, Cartesian=Cartesian)
|
grid_index = _index(i, ndim, cartesian=Cartesian)
|
||||||
shape_out += (F.shape(grids[grid_index])[0],)
|
shape_out += (F.shape(grids[grid_index])[0],)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for i, x in enumerate(grids):
|
for i, x in enumerate(grids):
|
||||||
grid_index = _index(i, ndim, Cartesian=Cartesian)
|
grid_index = _index(i, ndim, cartesian=Cartesian)
|
||||||
shape_expanded = _expanded_shape(ndim, shape_out[grid_index], grid_index)
|
shape_expanded = _expanded_shape(ndim, shape_out[grid_index], grid_index)
|
||||||
x = x.reshape(shape_expanded)
|
x = x.reshape(shape_expanded)
|
||||||
if not sparse:
|
if not sparse:
|
||||||
|
@ -1341,7 +1349,7 @@ class nd_grid:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
class mGridClass(nd_grid):
|
class MGridClass(nd_grid):
|
||||||
"""
|
"""
|
||||||
mgrid is an :class:`nd_grid` instance with ``sparse=False``.
|
mgrid is an :class:`nd_grid` instance with ``sparse=False``.
|
||||||
|
|
||||||
|
@ -1386,10 +1394,10 @@ class mGridClass(nd_grid):
|
||||||
[-1. -0.5 0. 0.5 1. ]
|
[-1. -0.5 0. 0.5 1. ]
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(mGridClass, self).__init__(sparse=False)
|
super(MGridClass, self).__init__(sparse=False)
|
||||||
|
|
||||||
|
|
||||||
class oGridClass(nd_grid):
|
class OGridClass(nd_grid):
|
||||||
"""
|
"""
|
||||||
ogrid is an :class:`nd_grid` instance with ``sparse=True``.
|
ogrid is an :class:`nd_grid` instance with ``sparse=True``.
|
||||||
|
|
||||||
|
@ -1428,13 +1436,13 @@ class oGridClass(nd_grid):
|
||||||
[-1. -0.5 0. 0.5 1. ]
|
[-1. -0.5 0. 0.5 1. ]
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(oGridClass, self).__init__(sparse=True)
|
super(OGridClass, self).__init__(sparse=True)
|
||||||
|
|
||||||
|
|
||||||
mgrid = mGridClass()
|
mgrid = MGridClass()
|
||||||
|
|
||||||
|
|
||||||
ogrid = oGridClass()
|
ogrid = OGridClass()
|
||||||
|
|
||||||
|
|
||||||
def diag(v, k=0):
|
def diag(v, k=0):
|
||||||
|
@ -1635,7 +1643,6 @@ def ix_(*args):
|
||||||
[1]]), Tensor(shape=[1, 2], dtype=Int32, value=
|
[1]]), Tensor(shape=[1, 2], dtype=Int32, value=
|
||||||
[[2, 4]]))
|
[[2, 4]]))
|
||||||
"""
|
"""
|
||||||
# TODO boolean mask
|
|
||||||
_check_input_tensor(*args)
|
_check_input_tensor(*args)
|
||||||
ndim = len(args)
|
ndim = len(args)
|
||||||
res = ()
|
res = ()
|
||||||
|
|
|
@ -50,7 +50,7 @@ ZERO_TENSOR = asarray_const(0)
|
||||||
|
|
||||||
_mean_keepdims = P.ReduceMean(True)
|
_mean_keepdims = P.ReduceMean(True)
|
||||||
_matmul = P.MatMul(False, False)
|
_matmul = P.MatMul(False, False)
|
||||||
_matmul_T = P.MatMul(False, True)
|
_matmul_t = P.MatMul(False, True)
|
||||||
_reduce_sum_default = P.ReduceSum()
|
_reduce_sum_default = P.ReduceSum()
|
||||||
_reduce_sum_keepdims = P.ReduceSum(True)
|
_reduce_sum_keepdims = P.ReduceSum(True)
|
||||||
_reduce_min_default = P.ReduceMin()
|
_reduce_min_default = P.ReduceMin()
|
||||||
|
@ -63,6 +63,7 @@ _cumprod_default = P.CumProd()
|
||||||
_round = P.Round()
|
_round = P.Round()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def absolute(x, dtype=None):
|
def absolute(x, dtype=None):
|
||||||
"""
|
"""
|
||||||
Calculates the absolute value element-wise.
|
Calculates the absolute value element-wise.
|
||||||
|
@ -669,7 +670,7 @@ def inner(a, b):
|
||||||
a_aligned = F.reshape(a, aligned_shape_a)
|
a_aligned = F.reshape(a, aligned_shape_a)
|
||||||
b_aligned = F.reshape(b, aligned_shape_b)
|
b_aligned = F.reshape(b, aligned_shape_b)
|
||||||
|
|
||||||
res = _matmul_T(a_aligned, b_aligned)
|
res = _matmul_t(a_aligned, b_aligned)
|
||||||
res = F.reshape(res, F.shape(a)[:-1] + F.shape(b)[:-1])
|
res = F.reshape(res, F.shape(a)[:-1] + F.shape(b)[:-1])
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -733,7 +734,7 @@ def dot(a, b):
|
||||||
a_aligned = F.reshape(a, (-1, F.shape(a)[-1]))
|
a_aligned = F.reshape(a, (-1, F.shape(a)[-1]))
|
||||||
b_aligned = F.reshape(b, (-1, F.shape(b)[-1]))
|
b_aligned = F.reshape(b, (-1, F.shape(b)[-1]))
|
||||||
|
|
||||||
res = _matmul_T(a_aligned, b_aligned)
|
res = _matmul_t(a_aligned, b_aligned)
|
||||||
res = F.reshape(res, F.shape(a)[:-1] + F.shape(b)[:-1])
|
res = F.reshape(res, F.shape(a)[:-1] + F.shape(b)[:-1])
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -1008,21 +1009,10 @@ def average(x, axis=None, weights=None, returned=False):
|
||||||
|
|
||||||
x_avg = full((), nan, F.dtype(x))
|
x_avg = full((), nan, F.dtype(x))
|
||||||
sum_of_weights = None
|
sum_of_weights = None
|
||||||
|
|
||||||
if weights is None:
|
if weights is None:
|
||||||
x_avg = mean(x, axis)
|
x_avg = mean(x, axis)
|
||||||
if axis is None:
|
sum_of_weights = compute_weights_for_mean(x, x_avg, axis)
|
||||||
sum_of_weights = full((), x.size, F.dtype(x))
|
|
||||||
else:
|
|
||||||
fill_value = 1
|
|
||||||
if isinstance(axis, int) or (isinstance(axis, tuple) and F.tuple_len(axis) == 1):
|
|
||||||
fill_value = x.shape[axis] if isinstance(axis, int) else x.shape[axis[0]]
|
|
||||||
elif axis is None:
|
|
||||||
for sh in x.shape:
|
|
||||||
fill_value *= sh
|
|
||||||
else:
|
|
||||||
for ax in axis:
|
|
||||||
fill_value *= x.shape[ax]
|
|
||||||
sum_of_weights = full_like(x_avg, fill_value, F.dtype(x))
|
|
||||||
else:
|
else:
|
||||||
_check_input_tensor(weights)
|
_check_input_tensor(weights)
|
||||||
if x.shape == weights.shape:
|
if x.shape == weights.shape:
|
||||||
|
@ -1043,6 +1033,24 @@ def average(x, axis=None, weights=None, returned=False):
|
||||||
return x_avg
|
return x_avg
|
||||||
|
|
||||||
|
|
||||||
|
def compute_weights_for_mean(x, x_avg, axis):
|
||||||
|
"""Computes weights for np.average."""
|
||||||
|
if axis is None:
|
||||||
|
sum_of_weights = full((), x.size, F.dtype(x))
|
||||||
|
else:
|
||||||
|
fill_value = 1
|
||||||
|
if isinstance(axis, int) or (isinstance(axis, tuple) and F.tuple_len(axis) == 1):
|
||||||
|
fill_value = x.shape[axis] if isinstance(axis, int) else x.shape[axis[0]]
|
||||||
|
elif axis is None:
|
||||||
|
for sh in x.shape:
|
||||||
|
fill_value *= sh
|
||||||
|
else:
|
||||||
|
for ax in axis:
|
||||||
|
fill_value *= x.shape[ax]
|
||||||
|
sum_of_weights = full_like(x_avg, fill_value, F.dtype(x))
|
||||||
|
return sum_of_weights
|
||||||
|
|
||||||
|
|
||||||
def comput_avg(x, axis, weights):
|
def comput_avg(x, axis, weights):
|
||||||
"""Computes average value of input x with given parameters."""
|
"""Computes average value of input x with given parameters."""
|
||||||
axis = () if axis is None else axis
|
axis = () if axis is None else axis
|
||||||
|
@ -1578,7 +1586,7 @@ def floor_divide(x1, x2, dtype=None):
|
||||||
return _apply_tensor_op(F.tensor_floordiv, x1, x2, dtype=dtype)
|
return _apply_tensor_op(F.tensor_floordiv, x1, x2, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def _remainder(x1, x2, C_style=False):
|
def _remainder(x1, x2, c_style=False):
|
||||||
"""Computes remainder without applying keyword arguments."""
|
"""Computes remainder without applying keyword arguments."""
|
||||||
dtype = _promote(F.dtype(x1), F.dtype(x2))
|
dtype = _promote(F.dtype(x1), F.dtype(x2))
|
||||||
if not _check_is_float(dtype):
|
if not _check_is_float(dtype):
|
||||||
|
@ -1586,7 +1594,7 @@ def _remainder(x1, x2, C_style=False):
|
||||||
x2 = F.cast(x2, mstype.float32)
|
x2 = F.cast(x2, mstype.float32)
|
||||||
|
|
||||||
quotient = F.tensor_div(x1, x2)
|
quotient = F.tensor_div(x1, x2)
|
||||||
if C_style:
|
if c_style:
|
||||||
quotient = fix(quotient)
|
quotient = fix(quotient)
|
||||||
else:
|
else:
|
||||||
quotient = F.floor(quotient)
|
quotient = F.floor(quotient)
|
||||||
|
@ -1671,7 +1679,7 @@ def fix(x):
|
||||||
if not _check_is_float(F.dtype(x)):
|
if not _check_is_float(F.dtype(x)):
|
||||||
x = F.cast(x, mstype.float32)
|
x = F.cast(x, mstype.float32)
|
||||||
floored = F.floor(x)
|
floored = F.floor(x)
|
||||||
# TODO change to F.ceil once supported on CPU.
|
# change to F.ceil once supported on CPU.
|
||||||
ceiled = F.neg_tensor(F.floor(F.neg_tensor(x)))
|
ceiled = F.neg_tensor(F.floor(F.neg_tensor(x)))
|
||||||
is_neg = F.tensor_lt(x, zeros(F.shape(x), F.dtype(x)))
|
is_neg = F.tensor_lt(x, zeros(F.shape(x), F.dtype(x)))
|
||||||
return F.select(is_neg, ceiled, floored)
|
return F.select(is_neg, ceiled, floored)
|
||||||
|
@ -1708,7 +1716,7 @@ def fmod(x1, x2, dtype=None):
|
||||||
>>> print(output)
|
>>> print(output)
|
||||||
[-1 0 -1 1 0 1]
|
[-1 0 -1 1 0 1]
|
||||||
"""
|
"""
|
||||||
return _apply_tensor_op(lambda x1, x2: _remainder(x1, x2, C_style=True), x1, x2, dtype=dtype)
|
return _apply_tensor_op(lambda x1, x2: _remainder(x1, x2, c_style=True), x1, x2, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def trunc(x, dtype=None):
|
def trunc(x, dtype=None):
|
||||||
|
@ -1845,6 +1853,19 @@ def divmod_(x1, x2, dtype=None):
|
||||||
return (q, r)
|
return (q, r)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_prepend_append(combined, tensor, additional_tensor, axis):
|
||||||
|
"""Concatenates prepend or append to tensor."""
|
||||||
|
if isinstance(additional_tensor, (int, float, bool)):
|
||||||
|
additional_tensor = asarray_const(additional_tensor)
|
||||||
|
elif not isinstance(additional_tensor, Tensor):
|
||||||
|
_raise_type_error("prepend must be scalar or Tensor, but got ", additional_tensor)
|
||||||
|
additional_shape = tensor.shape
|
||||||
|
additional_shape = _tuple_setitem(additional_shape, axis, 1)
|
||||||
|
additional_tensor = _broadcast_to_shape(additional_tensor, additional_shape)
|
||||||
|
combined += (additional_tensor,)
|
||||||
|
return combined
|
||||||
|
|
||||||
|
|
||||||
def diff(a, n=1, axis=-1, prepend=None, append=None):
|
def diff(a, n=1, axis=-1, prepend=None, append=None):
|
||||||
"""
|
"""
|
||||||
Calculates the n-th discrete difference along the given axis.
|
Calculates the n-th discrete difference along the given axis.
|
||||||
|
@ -1899,26 +1920,12 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):
|
||||||
|
|
||||||
combined = ()
|
combined = ()
|
||||||
if prepend is not None:
|
if prepend is not None:
|
||||||
if isinstance(prepend, (int, float, bool)):
|
combined = _handle_prepend_append(combined, a, prepend, axis)
|
||||||
prepend = asarray_const(prepend)
|
|
||||||
elif not isinstance(prepend, Tensor):
|
|
||||||
_raise_type_error("prepend must be scalar or Tensor, but got ", prepend)
|
|
||||||
prepend_shape = a.shape
|
|
||||||
prepend_shape = _tuple_setitem(prepend_shape, axis, 1)
|
|
||||||
prepend = _broadcast_to_shape(prepend, prepend_shape)
|
|
||||||
combined += (prepend,)
|
|
||||||
|
|
||||||
combined += (a,)
|
combined += (a,)
|
||||||
|
|
||||||
if append is not None:
|
if append is not None:
|
||||||
if isinstance(append, (int, float, bool)):
|
combined = _handle_prepend_append(combined, a, append, axis)
|
||||||
append = asarray_const(append)
|
|
||||||
elif not isinstance(append, Tensor):
|
|
||||||
_raise_type_error("append must be scalar or Tensor, but got ", append)
|
|
||||||
append_shape = a.shape
|
|
||||||
append_shape = _tuple_setitem(append_shape, axis, 1)
|
|
||||||
append = _broadcast_to_shape(append, append_shape)
|
|
||||||
combined += (append,)
|
|
||||||
|
|
||||||
if combined:
|
if combined:
|
||||||
a = concatenate(combined, axis)
|
a = concatenate(combined, axis)
|
||||||
|
@ -2239,6 +2246,22 @@ def _handle_inputs(cov_input, rowvar):
|
||||||
return cov_input
|
return cov_input
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_facts(w, m, ddof, aweights):
|
||||||
|
"""Computes facts for np.cov"""
|
||||||
|
fact = None
|
||||||
|
if w is None:
|
||||||
|
fact = m.shape[1] - ddof
|
||||||
|
else:
|
||||||
|
w_sum = _reduce_sum_default(w, -1)
|
||||||
|
if ddof == 0:
|
||||||
|
fact = w_sum
|
||||||
|
elif aweights is None:
|
||||||
|
fact = w_sum - ddof
|
||||||
|
else:
|
||||||
|
fact = w_sum - ddof * F.reduce_sum(w * aweights) / w_sum
|
||||||
|
return fact
|
||||||
|
|
||||||
|
|
||||||
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, dtype=None):
|
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, dtype=None):
|
||||||
"""
|
"""
|
||||||
Estimates a covariance matrix, given data and weights.
|
Estimates a covariance matrix, given data and weights.
|
||||||
|
@ -2328,23 +2351,14 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=N
|
||||||
avg = average(m, axis=1, weights=w)
|
avg = average(m, axis=1, weights=w)
|
||||||
|
|
||||||
# Determine the Normalization
|
# Determine the Normalization
|
||||||
if w is None:
|
fact = _handle_facts(w, m, ddof, aweights)
|
||||||
fact = m.shape[1] - ddof
|
|
||||||
else:
|
|
||||||
w_sum = _reduce_sum_default(w, -1)
|
|
||||||
if ddof == 0:
|
|
||||||
fact = w_sum
|
|
||||||
elif aweights is None:
|
|
||||||
fact = w_sum - ddof
|
|
||||||
else:
|
|
||||||
fact = w_sum - ddof * F.reduce_sum(w * aweights) / w_sum
|
|
||||||
|
|
||||||
m = m - F.expand_dims(avg, -1)
|
m = m - F.expand_dims(avg, -1)
|
||||||
if w is None:
|
if w is None:
|
||||||
m_T = m.T
|
m_t = m.T
|
||||||
else:
|
else:
|
||||||
m_T = (m * w).T
|
m_t = (m * w).T
|
||||||
res = true_divide(dot(m, m_T), fact).squeeze()
|
res = true_divide(dot(m, m_t), fact).squeeze()
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
return res.astype(dtype)
|
return res.astype(dtype)
|
||||||
return res
|
return res
|
||||||
|
@ -2417,7 +2431,7 @@ def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
||||||
if cmp_fn is None:
|
if cmp_fn is None:
|
||||||
initial = nan
|
initial = nan
|
||||||
else:
|
else:
|
||||||
return _raise_value_error('initial value must be provided for zero-size arrays')
|
_raise_value_error('initial value must be provided for zero-size arrays')
|
||||||
return full(shape_out, initial, dtype)
|
return full(shape_out, initial, dtype)
|
||||||
|
|
||||||
if initial is not None:
|
if initial is not None:
|
||||||
|
@ -2426,7 +2440,7 @@ def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
||||||
|
|
||||||
if isinstance(where, Tensor):
|
if isinstance(where, Tensor):
|
||||||
if initial is None:
|
if initial is None:
|
||||||
return _raise_value_error('initial value must be provided for where masks')
|
_raise_value_error('initial value must be provided for where masks')
|
||||||
ndim_orig = F.rank(a)
|
ndim_orig = F.rank(a)
|
||||||
a = where_(where, a, initial)
|
a = where_(where, a, initial)
|
||||||
axes = _real_axes(ndim_orig, F.rank(a), axes)
|
axes = _real_axes(ndim_orig, F.rank(a), axes)
|
||||||
|
@ -3277,8 +3291,10 @@ def log2(x, dtype=None):
|
||||||
[1. 2. 3.]
|
[1. 2. 3.]
|
||||||
"""
|
"""
|
||||||
tensor_2 = _make_tensor(2, x.dtype)
|
tensor_2 = _make_tensor(2, x.dtype)
|
||||||
|
|
||||||
def _log2(x):
|
def _log2(x):
|
||||||
return F.log(x) / F.log(tensor_2)
|
return F.log(x) / F.log(tensor_2)
|
||||||
|
|
||||||
return _apply_tensor_op(_log2, x, dtype=dtype)
|
return _apply_tensor_op(_log2, x, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -311,9 +311,12 @@ def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
type_str = ""
|
type_str = ""
|
||||||
if type_int: type_str += "int, "
|
if type_int:
|
||||||
if type_tuple: type_str += "tuple, "
|
type_str += "int, "
|
||||||
if type_list: type_str += "list, "
|
if type_tuple:
|
||||||
|
type_str += "tuple, "
|
||||||
|
if type_list:
|
||||||
|
type_str += "list, "
|
||||||
raise TypeError(f"Axis should be {type_str}but got {type(axis)}.")
|
raise TypeError(f"Axis should be {type_str}but got {type(axis)}.")
|
||||||
|
|
||||||
|
|
||||||
|
@ -451,7 +454,7 @@ def _tuple_setitem(tup, idx, value):
|
||||||
@constexpr
|
@constexpr
|
||||||
def _iota(dtype, num, increasing=True):
|
def _iota(dtype, num, increasing=True):
|
||||||
"""Creates a 1-D tensor with value: [0,1,...num-1] and dtype."""
|
"""Creates a 1-D tensor with value: [0,1,...num-1] and dtype."""
|
||||||
# TODO: Change to P.Linspace when the kernel is implemented on CPU.
|
# Change to P.Linspace when the kernel is implemented on CPU.
|
||||||
if num <= 0:
|
if num <= 0:
|
||||||
raise ValueError("zero shape Tensor is not currently supported.")
|
raise ValueError("zero shape Tensor is not currently supported.")
|
||||||
if increasing:
|
if increasing:
|
||||||
|
|
Loading…
Reference in New Issue