!16535 fix code_check warnings

From: @yanglf1121
Reviewed-by: @guoqi1024, @kingxian
Signed-off-by: @kingxian
commit 6bbbfe32a5
mindspore-ci-bot committed 2021-05-24 10:01:40 +08:00 (committed by Gitee)
3 changed files with 107 additions and 81 deletions


@@ -219,8 +219,7 @@ def asfarray(a, dtype=mstype.float32):
         return asarray(a)
     dtype = _check_dtype(dtype)
-    # pylint: disable=consider-using-in
-    if dtype != mstype.float16 and dtype != mstype.float32 and dtype != mstype.float64:
+    if dtype not in (mstype.float16, mstype.float32, mstype.float64):
         dtype = mstype.float32
     if isinstance(a, Tensor):
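Note: this resolves pylint's consider-using-in warning by collapsing a chain of != comparisons against the same variable into one tuple-membership test. The two forms are equivalent; a minimal plain-Python sketch, with strings standing in for the mstype constants:

    dtype = "float64"
    old_style = dtype != "float16" and dtype != "float32" and dtype != "float64"
    new_style = dtype not in ("float16", "float32", "float64")
    assert old_style == new_style  # same truth value for any dtype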
@@ -436,7 +435,7 @@ def arange(start, stop=None, step=None, dtype=None):
     return out.astype(dtype)
 
 
-def _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis):
+def _type_checking_for_xspace(start, stop, num, endpoint, dtype):
     """utility parameter checking function for linspace, logspace, geomspace."""
     if not isinstance(start, ARRAY_TYPES):
         _raise_type_error("start should be int, float, bool, list, tuple, Tensor, but got", start)
@@ -455,8 +454,18 @@ def _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis):
     else:
         dtype = mstype.float32
     start, stop = broadcast_arrays(start, stop)
-    axis = _canonicalize_axis(axis, start.ndim+1)
-    return start, stop, num, endpoint, dtype, axis
+    return start, stop, num, endpoint, dtype
+
+
+def _compute_shapes(start, axis, num, endpoint):
+    """Computes shapes for local variables for np.linspace"""
+    bounds_shape = start.shape
+    bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
+    iota_shape = _list_comprehensions(start.ndim+1, 1, True)
+    iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
+    num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
+    div = (num_tensor - 1) if endpoint else num_tensor
+    return bounds_shape, iota_shape, div
 
 
 def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
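Note: extracting _compute_shapes leaves the logic unchanged: bounds_shape is start.shape with a length-1 axis spliced in at `axis`, and iota_shape is all ones except `num` at `axis`, so the two broadcast against each other. A plain-tuple sketch, assuming start.shape == (2, 3), axis == 1, num == 5:

    shape, axis, num, ndim = (2, 3), 1, 5, 2
    bounds_shape = shape[:axis] + (1,) + shape[axis:]          # (2, 1, 3)
    iota_shape = (1,) * axis + (num,) + (1,) * (ndim - axis)   # (1, 5, 1)
    assert bounds_shape == (2, 1, 3) and iota_shape == (1, 5, 1)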
@@ -500,15 +509,11 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
     [0. 1. 2. 3. 4. 5.]
     """
     # This implementation was inspired by jax.numpy.linspace and numpy.linspace
-    start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis)
+    start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
+    axis = _canonicalize_axis(axis, start.ndim+1)
     if not isinstance(retstep, bool):
         _raise_type_error("retstep should be an boolean, but got ", retstep)
-    bounds_shape = start.shape
-    bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
-    iota_shape = _list_comprehensions(start.ndim+1, 1, True)
-    iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
-    num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
-    div = (num_tensor - 1) if endpoint else num_tensor
+    bounds_shape, iota_shape, div = _compute_shapes(start, axis, num, endpoint)
     out = None
     delta = None
     if num > 1:
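Note: with those shapes, linspace reduces to one broadcast multiply-add: start gains a length-1 axis, the index vector is scaled by the step delta, and the results broadcast together. A sketch of the same computation with NumPy standing in for the Tensor ops:

    import numpy as np
    start, stop, num, axis = np.zeros((2, 3)), np.ones((2, 3)), 5, 1
    delta = (stop - start) / (num - 1)              # endpoint=True, so div = num - 1
    iota = np.arange(num).reshape((1, num, 1))      # iota_shape from above
    out = np.expand_dims(start, axis) + iota * np.expand_dims(delta, axis)
    assert out.shape == (2, 5, 3) and np.allclose(out[:, -1, :], stop)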
@@ -572,7 +577,8 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
     [ 1. 2. 4. 8. 16. 32.]
     """
     # This implementation was inspired by jax.numpy.linspace and numpy.linspace
-    start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis)
+    start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
+    axis = _canonicalize_axis(axis, start.ndim+1)
     if not isinstance(base, (int, float, bool)):
         _raise_type_error("base should be a number, but got ", base)
     linspace_res = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis)
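Note: logspace delegates to linspace and exponentiates the result, so moving the _canonicalize_axis call out of the shared checker does not change its output; in NumPy terms:

    import numpy as np
    assert np.allclose(np.logspace(0, 5, 6, base=2.0), 2.0 ** np.linspace(0, 5, 6))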
@@ -620,7 +626,8 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
     >>> print(output)
     [ 1. 2. 4. 8. 16. 32. 64. 128.]
     """
-    start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis)
+    start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
+    axis = _canonicalize_axis(axis, start.ndim+1)
     root = num
     if endpoint:
         root -= 1
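Note: geomspace derives a constant step ratio from the endpoints: root is num - 1 when the endpoint is included (num otherwise), and entry i is start * ratio**i. A NumPy sketch matching the docstring example above:

    import numpy as np
    start, stop, num = 1.0, 128.0, 8
    root = num - 1                          # endpoint=True
    ratio = (stop / start) ** (1.0 / root)  # 2.0 here
    out = start * ratio ** np.arange(num)
    assert np.allclose(out, np.geomspace(start, stop, num))  # [1. 2. ... 128.]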
@@ -728,6 +735,7 @@ def identity(n, dtype=mstype.float32):
 
+
 @constexpr
 def empty_compile(dtype, shape):
     """Returns an empty Tensor."""
     return Tensor_(dtype, shape)
@@ -1183,9 +1191,9 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None):
     return a.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
 
 
-def _index(i, size, Cartesian=True):
-    """If Cartesian=True, index 0 is swapped with index 1."""
-    if Cartesian:
+def _index(i, size, cartesian=True):
+    """If cartesian=True, index 0 is swapped with index 1."""
+    if cartesian:
         if i == 1:
             return 0
         if i == 0 and size >= 2:
@@ -1272,12 +1280,12 @@ def meshgrid(*xi, sparse=False, indexing='xy'):
     Cartesian = indexing == 'xy'
     shape_out = ()
     for i in range(len(grids)):
-        grid_index = _index(i, ndim, Cartesian=Cartesian)
+        grid_index = _index(i, ndim, cartesian=Cartesian)
         shape_out += (F.shape(grids[grid_index])[0],)
 
     res = []
     for i, x in enumerate(grids):
-        grid_index = _index(i, ndim, Cartesian=Cartesian)
+        grid_index = _index(i, ndim, cartesian=Cartesian)
         shape_expanded = _expanded_shape(ndim, shape_out[grid_index], grid_index)
         x = x.reshape(shape_expanded)
         if not sparse:
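Note: the Cartesian= keyword becomes cartesian= to satisfy the argument-naming check; behavior is untouched. _index only matters for indexing='xy', where the first two grid dimensions are swapped relative to 'ij', the same convention NumPy uses:

    import numpy as np
    x, y = np.arange(3), np.arange(4)
    gx_xy, _ = np.meshgrid(x, y, indexing='xy')
    gx_ij, _ = np.meshgrid(x, y, indexing='ij')
    assert gx_xy.shape == (4, 3) and gx_ij.shape == (3, 4)
    assert np.array_equal(gx_xy, gx_ij.T)   # 'xy' is 'ij' with axes 0 and 1 swapped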
@@ -1341,7 +1349,7 @@ class nd_grid:
         return res
 
 
-class mGridClass(nd_grid):
+class MGridClass(nd_grid):
     """
     mgrid is an :class:`nd_grid` instance with ``sparse=False``.
@@ -1386,10 +1394,10 @@ class mGridClass(nd_grid):
     [-1. -0.5 0. 0.5 1. ]
     """
 
     def __init__(self):
-        super(mGridClass, self).__init__(sparse=False)
+        super(MGridClass, self).__init__(sparse=False)
 
 
-class oGridClass(nd_grid):
+class OGridClass(nd_grid):
     """
     ogrid is an :class:`nd_grid` instance with ``sparse=True``.
@@ -1428,13 +1436,13 @@ class oGridClass(nd_grid):
     [-1. -0.5 0. 0.5 1. ]
     """
 
     def __init__(self):
-        super(oGridClass, self).__init__(sparse=True)
+        super(OGridClass, self).__init__(sparse=True)
 
 
-mgrid = mGridClass()
+mgrid = MGridClass()
 
-ogrid = oGridClass()
+ogrid = OGridClass()
 
 
 def diag(v, k=0):
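Note: the classes move to PEP 8 CapWords (MGridClass, OGridClass) while the public mgrid/ogrid singletons keep their lowercase NumPy-style names. The sparse flag is what distinguishes them, mirroring NumPy: mgrid yields a dense stacked grid, ogrid broadcastable 1-D pieces:

    import numpy as np
    dense = np.mgrid[0:3, 0:4]     # sparse=False: one (2, 3, 4) stacked grid
    open_ = np.ogrid[0:3, 0:4]     # sparse=True: broadcastable pieces
    assert dense.shape == (2, 3, 4)
    assert open_[0].shape == (3, 1) and open_[1].shape == (1, 4)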
@@ -1635,7 +1643,6 @@ def ix_(*args):
     [1]]), Tensor(shape=[1, 2], dtype=Int32, value=
     [[2, 4]]))
     """
-    # TODO boolean mask
     _check_input_tensor(*args)
     ndim = len(args)
     res = ()


@@ -50,7 +50,7 @@ ZERO_TENSOR = asarray_const(0)
 _mean_keepdims = P.ReduceMean(True)
 _matmul = P.MatMul(False, False)
-_matmul_T = P.MatMul(False, True)
+_matmul_t = P.MatMul(False, True)
 _reduce_sum_default = P.ReduceSum()
 _reduce_sum_keepdims = P.ReduceSum(True)
 _reduce_min_default = P.ReduceMin()
@@ -63,6 +63,7 @@ _cumprod_default = P.CumProd()
 _round = P.Round()
 
+
 def absolute(x, dtype=None):
     """
     Calculates the absolute value element-wise.
@@ -669,7 +670,7 @@ def inner(a, b):
     a_aligned = F.reshape(a, aligned_shape_a)
     b_aligned = F.reshape(b, aligned_shape_b)
-    res = _matmul_T(a_aligned, b_aligned)
+    res = _matmul_t(a_aligned, b_aligned)
     res = F.reshape(res, F.shape(a)[:-1] + F.shape(b)[:-1])
     return res
@@ -733,7 +734,7 @@ def dot(a, b):
     a_aligned = F.reshape(a, (-1, F.shape(a)[-1]))
     b_aligned = F.reshape(b, (-1, F.shape(b)[-1]))
-    res = _matmul_T(a_aligned, b_aligned)
+    res = _matmul_t(a_aligned, b_aligned)
     res = F.reshape(res, F.shape(a)[:-1] + F.shape(b)[:-1])
     return res
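Note: _matmul_t is P.MatMul(False, True), i.e. matmul with transpose_b=True, so the rename is case-only. In inner, both operands are flattened to 2-D and contracted over their last axes; the same computation in NumPy:

    import numpy as np
    a, b = np.arange(24.).reshape(2, 3, 4), np.arange(20.).reshape(5, 4)
    a2, b2 = a.reshape(-1, a.shape[-1]), b.reshape(-1, b.shape[-1])
    res = (a2 @ b2.T).reshape(a.shape[:-1] + b.shape[:-1])
    assert res.shape == (2, 3, 5) and np.allclose(res, np.inner(a, b))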
@@ -1008,21 +1009,10 @@ def average(x, axis=None, weights=None, returned=False):
         x_avg = full((), nan, F.dtype(x))
     sum_of_weights = None
 
     if weights is None:
         x_avg = mean(x, axis)
-        if axis is None:
-            sum_of_weights = full((), x.size, F.dtype(x))
-        else:
-            fill_value = 1
-            if isinstance(axis, int) or (isinstance(axis, tuple) and F.tuple_len(axis) == 1):
-                fill_value = x.shape[axis] if isinstance(axis, int) else x.shape[axis[0]]
-            elif axis is None:
-                for sh in x.shape:
-                    fill_value *= sh
-            else:
-                for ax in axis:
-                    fill_value *= x.shape[ax]
-            sum_of_weights = full_like(x_avg, fill_value, F.dtype(x))
+        sum_of_weights = compute_weights_for_mean(x, x_avg, axis)
     else:
         _check_input_tensor(weights)
         if x.shape == weights.shape:
@@ -1043,6 +1033,24 @@ def average(x, axis=None, weights=None, returned=False):
     return x_avg
 
 
+def compute_weights_for_mean(x, x_avg, axis):
+    """Computes weights for np.average."""
+    if axis is None:
+        sum_of_weights = full((), x.size, F.dtype(x))
+    else:
+        fill_value = 1
+        if isinstance(axis, int) or (isinstance(axis, tuple) and F.tuple_len(axis) == 1):
+            fill_value = x.shape[axis] if isinstance(axis, int) else x.shape[axis[0]]
+        elif axis is None:
+            for sh in x.shape:
+                fill_value *= sh
+        else:
+            for ax in axis:
+                fill_value *= x.shape[ax]
+        sum_of_weights = full_like(x_avg, fill_value, F.dtype(x))
+    return sum_of_weights
+
+
 def comput_avg(x, axis, weights):
     """Computes average value of input x with given parameters."""
     axis = () if axis is None else axis
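Note: the extracted helper computes, as a tensor, the number of elements that mean(x, axis) averaged over: x.size when axis is None, otherwise the product of the reduced dimensions. (The inner `elif axis is None` branch is unreachable, since the enclosing else already excludes None; it is carried over verbatim from the inlined original.) A plain-Python sketch of that count, using a hypothetical helper name weight_count:

    from functools import reduce
    def weight_count(shape, axis):
        # number of elements reduced by mean(x, axis) for an array of this shape
        if axis is None:
            return reduce(lambda acc, sh: acc * sh, shape, 1)
        axes = (axis,) if isinstance(axis, int) else axis
        return reduce(lambda acc, ax: acc * shape[ax], axes, 1)
    assert weight_count((2, 3, 4), None) == 24
    assert weight_count((2, 3, 4), 1) == 3
    assert weight_count((2, 3, 4), (0, 2)) == 8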
@@ -1578,7 +1586,7 @@ def floor_divide(x1, x2, dtype=None):
     return _apply_tensor_op(F.tensor_floordiv, x1, x2, dtype=dtype)
 
 
-def _remainder(x1, x2, C_style=False):
+def _remainder(x1, x2, c_style=False):
     """Computes remainder without applying keyword arguments."""
     dtype = _promote(F.dtype(x1), F.dtype(x2))
     if not _check_is_float(dtype):
@@ -1586,7 +1594,7 @@ def _remainder(x1, x2, C_style=False):
         x2 = F.cast(x2, mstype.float32)
     quotient = F.tensor_div(x1, x2)
 
-    if C_style:
+    if c_style:
         quotient = fix(quotient)
     else:
         quotient = F.floor(quotient)
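Note: c_style selects truncated division (quotient rounded toward zero, as in C's fmod) rather than floored division; the results differ only when the operands have opposite signs. In NumPy terms:

    import numpy as np
    x1, x2 = np.array([-3.0, 3.0]), np.array([2.0, -2.0])
    trunc_rem = x1 - np.trunc(x1 / x2) * x2   # c_style=True  -> fmod
    floor_rem = x1 - np.floor(x1 / x2) * x2   # c_style=False -> mod
    assert np.array_equal(trunc_rem, np.fmod(x1, x2))  # [-1.  1.]
    assert np.array_equal(floor_rem, np.mod(x1, x2))   # [ 1. -1.]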
@@ -1671,7 +1679,7 @@ def fix(x):
     if not _check_is_float(F.dtype(x)):
         x = F.cast(x, mstype.float32)
     floored = F.floor(x)
-    # TODO change to F.ceil once supported on CPU.
+    # change to F.ceil once supported on CPU.
     ceiled = F.neg_tensor(F.floor(F.neg_tensor(x)))
     is_neg = F.tensor_lt(x, zeros(F.shape(x), F.dtype(x)))
     return F.select(is_neg, ceiled, floored)
@@ -1708,7 +1716,7 @@ def fmod(x1, x2, dtype=None):
     >>> print(output)
     [-1 0 -1 1 0 1]
     """
-    return _apply_tensor_op(lambda x1, x2: _remainder(x1, x2, C_style=True), x1, x2, dtype=dtype)
+    return _apply_tensor_op(lambda x1, x2: _remainder(x1, x2, c_style=True), x1, x2, dtype=dtype)
 
 
 def trunc(x, dtype=None):
@@ -1845,6 +1853,19 @@ def divmod_(x1, x2, dtype=None):
     return (q, r)
 
 
+def _handle_prepend_append(combined, tensor, additional_tensor, axis):
+    """Concatenates prepend or append to tensor."""
+    if isinstance(additional_tensor, (int, float, bool)):
+        additional_tensor = asarray_const(additional_tensor)
+    elif not isinstance(additional_tensor, Tensor):
+        _raise_type_error("prepend must be scalar or Tensor, but got ", additional_tensor)
+    additional_shape = tensor.shape
+    additional_shape = _tuple_setitem(additional_shape, axis, 1)
+    additional_tensor = _broadcast_to_shape(additional_tensor, additional_shape)
+    combined += (additional_tensor,)
+    return combined
+
+
 def diff(a, n=1, axis=-1, prepend=None, append=None):
     """
     Calculates the n-th discrete difference along the given axis.
@@ -1899,26 +1920,12 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):
     combined = ()
     if prepend is not None:
-        if isinstance(prepend, (int, float, bool)):
-            prepend = asarray_const(prepend)
-        elif not isinstance(prepend, Tensor):
-            _raise_type_error("prepend must be scalar or Tensor, but got ", prepend)
-        prepend_shape = a.shape
-        prepend_shape = _tuple_setitem(prepend_shape, axis, 1)
-        prepend = _broadcast_to_shape(prepend, prepend_shape)
-        combined += (prepend,)
+        combined = _handle_prepend_append(combined, a, prepend, axis)
 
     combined += (a,)
 
     if append is not None:
-        if isinstance(append, (int, float, bool)):
-            append = asarray_const(append)
-        elif not isinstance(append, Tensor):
-            _raise_type_error("append must be scalar or Tensor, but got ", append)
-        append_shape = a.shape
-        append_shape = _tuple_setitem(append_shape, axis, 1)
-        append = _broadcast_to_shape(append, append_shape)
-        combined += (append,)
+        combined = _handle_prepend_append(combined, a, append, axis)
 
     if combined:
         a = concatenate(combined, axis)
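Note: the helper deduplicates two identical blocks: the prepend/append value is converted to a tensor, broadcast to the input's shape with the differencing axis collapsed to 1, and queued for concatenation. (One side effect: the error message now says "prepend" even on the append path.) NumPy's diff treats prepend/append the same way:

    import numpy as np
    a = np.array([1, 3, 6, 10])
    out = np.diff(a, prepend=0, append=15)
    assert np.array_equal(out, np.diff(np.concatenate(([0], a, [15]))))  # [1 2 3 4 5]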
@@ -2239,6 +2246,22 @@ def _handle_inputs(cov_input, rowvar):
     return cov_input
 
 
+def _handle_facts(w, m, ddof, aweights):
+    """Computes facts for np.cov"""
+    fact = None
+    if w is None:
+        fact = m.shape[1] - ddof
+    else:
+        w_sum = _reduce_sum_default(w, -1)
+        if ddof == 0:
+            fact = w_sum
+        elif aweights is None:
+            fact = w_sum - ddof
+        else:
+            fact = w_sum - ddof * F.reduce_sum(w * aweights) / w_sum
+    return fact
+
+
 def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, dtype=None):
     """
     Estimates a covariance matrix, given data and weights.
@@ -2328,23 +2351,14 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=N
     avg = average(m, axis=1, weights=w)
 
     # Determine the Normalization
-    if w is None:
-        fact = m.shape[1] - ddof
-    else:
-        w_sum = _reduce_sum_default(w, -1)
-        if ddof == 0:
-            fact = w_sum
-        elif aweights is None:
-            fact = w_sum - ddof
-        else:
-            fact = w_sum - ddof * F.reduce_sum(w * aweights) / w_sum
+    fact = _handle_facts(w, m, ddof, aweights)
 
     m = m - F.expand_dims(avg, -1)
     if w is None:
-        m_T = m.T
+        m_t = m.T
     else:
-        m_T = (m * w).T
-    res = true_divide(dot(m, m_T), fact).squeeze()
+        m_t = (m * w).T
+    res = true_divide(dot(m, m_t), fact).squeeze()
     if dtype is not None:
         return res.astype(dtype)
     return res
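Note: _handle_facts centralizes the covariance normalization factor: N - ddof for unweighted data; with weights, sum(w) adjusted by ddof and, when aweights are given, by ddof * sum(w*a)/sum(w). For the unweighted case this is the familiar divisor behind np.cov:

    import numpy as np
    m = np.array([[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]])
    ddof = 1
    fact = m.shape[1] - ddof                    # the "w is None" branch
    centered = m - m.mean(axis=1, keepdims=True)
    assert np.allclose(centered @ centered.T / fact, np.cov(m, ddof=ddof))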
@@ -2417,7 +2431,7 @@ def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
         if cmp_fn is None:
             initial = nan
         else:
-            return _raise_value_error('initial value must be provided for zero-size arrays')
+            _raise_value_error('initial value must be provided for zero-size arrays')
         return full(shape_out, initial, dtype)
 
     if initial is not None:
@@ -2426,7 +2440,7 @@ def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
     if isinstance(where, Tensor):
         if initial is None:
-            return _raise_value_error('initial value must be provided for where masks')
+            _raise_value_error('initial value must be provided for where masks')
         ndim_orig = F.rank(a)
         a = where_(where, a, initial)
         axes = _real_axes(ndim_orig, F.rank(a), axes)
@@ -3277,8 +3291,10 @@ def log2(x, dtype=None):
     [1. 2. 3.]
     """
     tensor_2 = _make_tensor(2, x.dtype)
+
     def _log2(x):
         return F.log(x) / F.log(tensor_2)
+
     return _apply_tensor_op(_log2, x, dtype=dtype)
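Note: the nested helper computes log2 through the change-of-base identity log2(x) = log(x) / log(2), presumably because only a natural-log primitive is assumed available; the blank lines added around the def are the style fix. A quick NumPy check of the identity:

    import numpy as np
    x = np.array([2.0, 4.0, 8.0])
    assert np.allclose(np.log(x) / np.log(2.0), np.log2(x))  # [1. 2. 3.]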


@@ -311,9 +311,12 @@ def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
         return True
 
     type_str = ""
-    if type_int: type_str += "int, "
-    if type_tuple: type_str += "tuple, "
-    if type_list: type_str += "list, "
+    if type_int:
+        type_str += "int, "
+    if type_tuple:
+        type_str += "tuple, "
+    if type_list:
+        type_str += "list, "
     raise TypeError(f"Axis should be {type_str}but got {type(axis)}.")
@@ -451,7 +454,7 @@ def _tuple_setitem(tup, idx, value):
 @constexpr
 def _iota(dtype, num, increasing=True):
     """Creates a 1-D tensor with value: [0,1,...num-1] and dtype."""
-    # TODO: Change to P.Linspace when the kernel is implemented on CPU.
+    # Change to P.Linspace when the kernel is implemented on CPU.
     if num <= 0:
         raise ValueError("zero shape Tensor is not currently supported.")
     if increasing: