forked from mindspore-Ecosystem/mindspore
fix np.average
This commit is contained in:
parent 2088d38610
commit 03ec134e87
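
For context, mindspore.numpy.average is intended to follow NumPy's np.average semantics: with returned=True it yields an (average, sum_of_weights) pair even when weights is None, and 1-D weights are applied along the requested axis. A minimal sketch of that reference behaviour in plain NumPy (not the MindSpore implementation):

import numpy as np

x = np.arange(6.).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# No weights: the second element of the pair is the element count
# along the reduced axis (3 per row here).
avg, wsum = np.average(x, axis=1, returned=True)
# avg  -> [1., 4.],  wsum -> [3., 3.]

# 1-D weights are lined up with the given axis by broadcasting.
w = np.array([1., 2., 3.])
avg_w = np.average(x, axis=1, weights=w)
# avg_w -> [(0*1 + 1*2 + 2*3) / 6, (3*1 + 4*2 + 5*3) / 6] = [1.33..., 4.33...]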
@@ -25,14 +25,14 @@ from ..common import Tensor
 
 from .dtypes import nan, pi
 
-from .array_creations import asarray_const, ones, zeros, empty, full
+from .array_creations import asarray_const, ones, zeros, empty, full, full_like
 from .array_ops import where as where_
 from .array_ops import ravel, expand_dims
 
 from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
     _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
     _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \
-    _max, _is_shape_empty, _check_is_int
+    _max, _is_shape_empty, _check_is_int, _expanded_shape
 from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \
     _check_input_tensor
 
@@ -1200,22 +1200,38 @@ def average(x, axis=None, weights=None, returned=False):
     _check_axis_type(axis, True, True, False)
     axis = _canonicalize_axis(axis, x.ndim)
 
-    if weights is None:
-        return mean(x, axis)
-
     x_avg = full((), nan, F.dtype(x))
     sum_of_weights = None
+    if weights is None:
+        x_avg = mean(x, axis)
+        if axis is None:
+            sum_of_weights = full((), x.size, F.dtype(x))
+        else:
+            fill_value = 1
+            if isinstance(axis, int) or isinstance(axis, tuple) and F.tuple_len(axis) == 1:
+                fill_value = x.shape[axis]
+            elif axis is None or axis == ():
+                for sh in x.shape:
+                    fill_value *= sh
+            else:
+                for ax in axis:
+                    fill_value *= x.shape[ax]
+            sum_of_weights = full_like(x_avg, fill_value, F.dtype(x))
+    else:
         if x.shape == weights.shape:
             x_avg, sum_of_weights = comput_avg(x, axis, weights)
         elif F.rank(weights) == 1:
             if not isinstance(axis, int):
                 _raise_type_error("Axis must be specified when shapes of x and weights differ.")
-            weights = _broadcast_to_shape(weights, x.shape)
+            perm = _expanded_shape(x.ndim, weights.shape[0], axis)
+            weights = weights.reshape(perm)
             x_avg, sum_of_weights = comput_avg(x, axis, weights)
         else:
-            _raise_type_error("Weights should be None, 1-D or the same as input x, but got shape of", weights)
+            _raise_type_error("Weights should be None, 1-D or the same shape as input x.")
 
     if returned:
+        if x_avg.shape != sum_of_weights.shape:
+            sum_of_weights = _broadcast_to(sum_of_weights, sum_of_weights.shape, x_avg.shape, x_avg.ndim)
         return (x_avg, sum_of_weights)
     return x_avg
 
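
The 1-D weights branch now reshapes the weights to a shape that is 1 on every axis except the reduction axis (the shape supplied by the newly imported _expanded_shape helper), so ordinary broadcasting aligns them with x instead of materialising a full-size copy via _broadcast_to_shape. A rough sketch of the idea in plain NumPy; expanded_shape below is a hypothetical stand-in, not the MindSpore helper:

import numpy as np

def expanded_shape(ndim, size, axis):
    # Hypothetical stand-in: all-ones shape with `size` at position `axis`,
    # e.g. (1, 3, 1) for ndim=3, size=3, axis=1.
    return tuple(size if i == axis else 1 for i in range(ndim))

x = np.arange(24.).reshape(2, 3, 4)
w = np.array([1., 2., 3.])                                # 1-D weights for axis=1

w_exp = w.reshape(expanded_shape(x.ndim, w.shape[0], 1))  # shape (1, 3, 1)
avg = (x * w_exp).sum(axis=1) / w_exp.sum()               # weighted mean over axis 1
assert np.allclose(avg, np.average(x, axis=1, weights=w))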