!49084 Fix cov error

Merge pull request !49084 from 冯一航/fix_cov_error
This commit is contained in:
i-robot 2023-02-20 10:57:39 +00:00 committed by Gitee
commit 1e0c5c719e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 10 additions and 4 deletions

View File

@ -24,6 +24,9 @@ mindspore.ops.cov
.. warning::
`fweights``aweights` 的值不能为负数,负数权重场景结果未定义。
.. note::
当前暂不支持复数。
参数:
- **x** (Tensor) - 一个二维矩阵,或单个变量的标量或一维向量。

View File

@ -1953,6 +1953,9 @@ def cov(x, *, correction=1, fweights=None, aweights=None):
.. warning::
The values of `fweights` and `aweights` cannot be negative, and the negative weight scene result is undefined.
.. note::
Currently, complex number is not supported.
Args:
x (Tensor): A 2D matrix, or a scalar or 1D vector of a single variable
@ -2028,10 +2031,10 @@ def cov(x, *, correction=1, fweights=None, aweights=None):
w = None
if w is not None:
w_sum = w.sum().astype(mstype.float32)
w_sum = w.sum()
avg = (input_x * w).sum(1) / w_sum
else:
w_sum = Tensor(num_observations, dtype=mstype.float32)
w_sum = Tensor(num_observations, dtype=mstype.int64)
avg = input_x.sum(1) / w_sum
if w is not None and aweights is not None and correction != 0:

View File

@ -43,7 +43,7 @@ def test_cov_normal(mode):
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([[0, 2], [1, 1], [2, 0]]).T
x = ms.Tensor([[0., 2.], [1., 1.], [2., 0.]]).T
output1 = net(x)
expect_output1 = np.array([[1., -1.],
[-1., 1.]])

View File

@ -42,7 +42,7 @@ def test_cov_normal(mode):
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([[0, 2], [1, 1], [2, 0]]).T
x = ms.Tensor([[0., 2.], [1., 1.], [2., 0.]]).T
output1 = net(x)
expect_output1 = np.array([[1., -1.],
[-1., 1.]])