fix np.average

This commit is contained in:
wangrao 2021-02-22 15:29:55 +08:00
parent 2088d38610
commit 03ec134e87
1 changed files with 29 additions and 13 deletions

View File

@ -25,14 +25,14 @@ from ..common import Tensor
from .dtypes import nan, pi from .dtypes import nan, pi
from .array_creations import asarray_const, ones, zeros, empty, full from .array_creations import asarray_const, ones, zeros, empty, full, full_like
from .array_ops import where as where_ from .array_ops import where as where_
from .array_ops import ravel, expand_dims from .array_ops import ravel, expand_dims
from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
_check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
_raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \ _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \
_max, _is_shape_empty, _check_is_int _max, _is_shape_empty, _check_is_int, _expanded_shape
from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \
_check_input_tensor _check_input_tensor
@ -1200,22 +1200,38 @@ def average(x, axis=None, weights=None, returned=False):
_check_axis_type(axis, True, True, False) _check_axis_type(axis, True, True, False)
axis = _canonicalize_axis(axis, x.ndim) axis = _canonicalize_axis(axis, x.ndim)
if weights is None:
return mean(x, axis)
x_avg = full((), nan, F.dtype(x)) x_avg = full((), nan, F.dtype(x))
sum_of_weights = None sum_of_weights = None
if x.shape == weights.shape: if weights is None:
x_avg, sum_of_weights = comput_avg(x, axis, weights) x_avg = mean(x, axis)
elif F.rank(weights) == 1: if axis is None:
if not isinstance(axis, int): sum_of_weights = full((), x.size, F.dtype(x))
_raise_type_error("Axis must be specified when shapes of x and weights differ.") else:
weights = _broadcast_to_shape(weights, x.shape) fill_value = 1
x_avg, sum_of_weights = comput_avg(x, axis, weights) if isinstance(axis, int) or isinstance(axis, tuple) and F.tuple_len(axis) == 1:
fill_value = x.shape[axis]
elif axis is None or axis == ():
for sh in x.shape:
fill_value *= sh
else:
for ax in axis:
fill_value *= x.shape[ax]
sum_of_weights = full_like(x_avg, fill_value, F.dtype(x))
else: else:
_raise_type_error("Weights should be None, 1-D or the same as input x, but got shape of", weights) if x.shape == weights.shape:
x_avg, sum_of_weights = comput_avg(x, axis, weights)
elif F.rank(weights) == 1:
if not isinstance(axis, int):
_raise_type_error("Axis must be specified when shapes of x and weights differ.")
perm = _expanded_shape(x.ndim, weights.shape[0], axis)
weights = weights.reshape(perm)
x_avg, sum_of_weights = comput_avg(x, axis, weights)
else:
_raise_type_error("Weights should be None, 1-D or the same shape as input x.")
if returned: if returned:
if x_avg.shape != sum_of_weights.shape:
sum_of_weights = _broadcast_to(sum_of_weights, sum_of_weights.shape, x_avg.shape, x_avg.ndim)
return (x_avg, sum_of_weights) return (x_avg, sum_of_weights)
return x_avg return x_avg