forked from mindspore-Ecosystem/mindspore
fix cov error
This commit is contained in:
parent
9dff86b515
commit
1b8726ef81
|
@ -1925,6 +1925,13 @@ def _check_cov_weights(weights, weights_name, num_observations, valid_type, vali
|
|||
return 0
|
||||
|
||||
|
||||
def _get_default_div_type(param):
|
||||
"""get the default type when div"""
|
||||
if param.dtype == mstype.float64:
|
||||
return param
|
||||
return param.astype(mstype.float32)
|
||||
|
||||
|
||||
def cov(x, *, correction=1, fweights=None, aweights=None):
|
||||
r"""
|
||||
Given the input `x` and weights, returns the covariance matrix (the square matrix of the covariance of each pair of
|
||||
|
@ -2032,23 +2039,22 @@ def cov(x, *, correction=1, fweights=None, aweights=None):
|
|||
|
||||
if w is not None:
|
||||
w_sum = w.sum()
|
||||
avg = (input_x * w).sum(1) / w_sum
|
||||
avg = (input_x * w).sum(1) / _get_default_div_type(w_sum)
|
||||
else:
|
||||
w_sum = Tensor(num_observations, dtype=mstype.int64)
|
||||
avg = input_x.sum(1) / w_sum
|
||||
avg = input_x.sum(1) / _get_default_div_type(w_sum)
|
||||
|
||||
if w is not None and aweights is not None and correction != 0:
|
||||
norm_factor = w_sum - correction * (w * aweights).sum() / w_sum
|
||||
else:
|
||||
norm_factor = w_sum - correction
|
||||
|
||||
if norm_factor <= 0:
|
||||
norm_factor = ops.zeros_like(norm_factor)
|
||||
norm_factor = norm_factor.clip(min=0)
|
||||
|
||||
input_x = input_x - avg.unsqueeze(1)
|
||||
c = ops.mm(input_x, (input_x * w if w is not None else input_x).T)
|
||||
norm_factor = norm_factor.astype(mstype.float32)
|
||||
return ops.true_divide(c, norm_factor).squeeze()
|
||||
return ops.true_divide(c, _get_default_div_type(norm_factor)).squeeze()
|
||||
|
||||
|
||||
def t(x):
|
||||
|
|
Loading…
Reference in New Issue