fix cov error

This commit is contained in:
fengyihang 2023-03-03 11:21:31 +08:00
parent 9dff86b515
commit 1b8726ef81
1 changed files with 11 additions and 5 deletions

View File

@ -1925,6 +1925,13 @@ def _check_cov_weights(weights, weights_name, num_observations, valid_type, vali
return 0
def _get_default_div_type(param):
"""get the default type when div"""
if param.dtype == mstype.float64:
return param
return param.astype(mstype.float32)
def cov(x, *, correction=1, fweights=None, aweights=None):
r"""
Given the input `x` and weights, returns the covariance matrix (the square matrix of the covariance of each pair of
@ -2032,23 +2039,22 @@ def cov(x, *, correction=1, fweights=None, aweights=None):
if w is not None:
w_sum = w.sum()
avg = (input_x * w).sum(1) / w_sum
avg = (input_x * w).sum(1) / _get_default_div_type(w_sum)
else:
w_sum = Tensor(num_observations, dtype=mstype.int64)
avg = input_x.sum(1) / w_sum
avg = input_x.sum(1) / _get_default_div_type(w_sum)
if w is not None and aweights is not None and correction != 0:
norm_factor = w_sum - correction * (w * aweights).sum() / w_sum
else:
norm_factor = w_sum - correction
if norm_factor <= 0:
norm_factor = ops.zeros_like(norm_factor)
norm_factor = norm_factor.clip(min=0)
input_x = input_x - avg.unsqueeze(1)
c = ops.mm(input_x, (input_x * w if w is not None else input_x).T)
norm_factor = norm_factor.astype(mstype.float32)
return ops.true_divide(c, norm_factor).squeeze()
return ops.true_divide(c, _get_default_div_type(norm_factor)).squeeze()
def t(x):