forked from mindspore-Ecosystem/mindspore
!47658 Support cov function
Merge pull request !47658 from 冯一航/support_cov
This commit is contained in:
commit
fc1f1419e4
|
@ -30,6 +30,7 @@ mindspore/mindspore/python/mindspore/ops/operations/array_ops.py:_compute_slicin
|
|||
mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd
|
||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool3d
|
||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
|
||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
|
||||
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
|
||||
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
|
||||
mindspore/mindspore/python/mindspore/common/parameter.py:set_data
|
||||
|
|
|
@ -223,6 +223,7 @@ mindspore.ops
|
|||
mindspore.ops.cos
|
||||
mindspore.ops.cosh
|
||||
mindspore.ops.cosine_similarity
|
||||
mindspore.ops.cov
|
||||
mindspore.ops.deg2rad
|
||||
mindspore.ops.digamma
|
||||
mindspore.ops.div
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.cov
|
||||
=====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.cov()
|
||||
|
||||
详情请参考 :func:`mindspore.ops.cov`。
|
|
@ -82,6 +82,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.copysign
|
||||
mindspore.Tensor.cos
|
||||
mindspore.Tensor.cosh
|
||||
mindspore.Tensor.cov
|
||||
mindspore.Tensor.cross
|
||||
mindspore.Tensor.cummax
|
||||
mindspore.Tensor.cummin
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
mindspore.ops.cov
|
||||
==================
|
||||
|
||||
.. py:function:: mindspore.ops.cov(x, *, correction=1, fweights=None, aweights=None)
|
||||
|
||||
估计由输入矩阵给出的变量的协方差矩阵,其中行是变量,列是观察值。
|
||||
|
||||
协方差矩阵是给出每对变量的协方差的方阵。对角线包含每个变量的方差(变量与其自身的协方差)。根据定义,如果 `x` 表示单个变量(标量或一维),则返回其方差。
|
||||
|
||||
变量 :math:`a` 和 :math:`b` 的无偏样本协方差由下式给出:
|
||||
|
||||
.. math::
|
||||
\text{cov}(a,b) = \frac{\sum^{N}_{i = 1}(a_{i} - \bar{a})(b_{i} - \bar{b})}{N~-~1}
|
||||
|
||||
其中 :math:`\bar{a}` 和 :math:`\bar{b}` 分别是 :math:`a` 和 :math:`b` 的简单均值。
|
||||
|
||||
如果提供了 `fweights` 和/或 `aweights` ,则计算无偏加权协方差,由下式给出:
|
||||
|
||||
.. math::
|
||||
\text{cov}_w(a,b) = \frac{\sum^{N}_{i = 1}w_i(a_{i} - \mu_a^*)(b_{i} - \mu_b^*)}{\sum^{N}_{i = 1}w_i~-~1}
|
||||
|
||||
其中 :math:`w` 基于提供的 `fweights` 或 `aweights` 中的任意一个参数进行表示,如果两个参数都有提供,则 :math:`w = fweights \times aweights`,并且 :math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i} }{\sum^{N}_{i = 1}w_i}` 表示变量的加权平均值。
|
||||
|
||||
.. warning::
|
||||
`fweights` 和 `aweights` 的值不能为负数,负数权重场景结果未定义。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 包含多个变量和观察值的二维矩阵,或表示单个变量的标量或一维向量。
|
||||
|
||||
关键字参数:
|
||||
- **correction** (int,可选) - 样本量和样本自由度之间的差异,默认为Bessel校正 `correction = 1`,即使指定了 `fweights` 和 `aweights` 的情况下它也会返回无偏估计。`correction = 0` 将返回简单平均值。默认值:1。
|
||||
- **fweights** (Tensor, 可选) - 观察向量频率的标量或一维Tensor,表示每个观察应重复的次数。它的numel必须等于输入 `x` 的列数。必须为整型数据类型。若为None则忽略。默认值:None。
|
||||
- **aweights** (Tensor, 可选) - 观察向量权重的标量或一维数组。对于较为重要的观察值,这些相对权重通常较大,而对于相对不够重要的观察值,这些相对权重较小。它的numel必须等于输入 `x` 的列数。必须为浮点数据类型。若为None则忽略。默认值:None。
|
||||
|
||||
返回:
|
||||
Tensor,入参的协方差矩阵计算结果。
|
||||
|
||||
异常:
|
||||
- **ValueError** - 如果输入的维度大于2。
|
||||
- **ValueError** - 如果 `fweights` 的维度大于1。
|
||||
- **ValueError** - 如果 `fweights` 的numel不等于输入 `x` 的列数。
|
||||
- **ValueError** - 如果 `aweights` 的numel不等于输入 `x` 的列数。
|
||||
- **ValueError** - 如果 `aweights` 的维度大于1。
|
||||
- **TypeError** - 如果输入的类型为bool类型。
|
||||
- **TypeError** - 如果 `fweights` 的数据类型不为整型。
|
||||
- **TypeError** - 如果 `aweights` 的类型不为浮点类型。
|
|
@ -88,6 +88,7 @@
|
|||
mindspore.Tensor.copysign
|
||||
mindspore.Tensor.cos
|
||||
mindspore.Tensor.cosh
|
||||
mindspore.Tensor.cov
|
||||
mindspore.Tensor.cross
|
||||
mindspore.Tensor.cummax
|
||||
mindspore.Tensor.cummin
|
||||
|
|
|
@ -223,6 +223,7 @@ Element-by-Element Operations
|
|||
mindspore.ops.cos
|
||||
mindspore.ops.cosh
|
||||
mindspore.ops.cosine_similarity
|
||||
mindspore.ops.cov
|
||||
mindspore.ops.deg2rad
|
||||
mindspore.ops.digamma
|
||||
mindspore.ops.div
|
||||
|
|
|
@ -374,6 +374,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"top_k", std::string("top_k")}, // P.TopK()
|
||||
{"isfinite", std::string("isfinite")}, // P.isfinite()
|
||||
{"cos", std::string("cos")}, // cos()
|
||||
{"cov", std::string("cov")}, // cov()
|
||||
{"acos", std::string("acos")}, // acos()
|
||||
{"arccos", std::string("acos")}, // acos()
|
||||
{"acosh", std::string("acosh")}, // acosh()
|
||||
|
|
|
@ -3727,6 +3727,13 @@ def cos(x):
|
|||
return F.cos(x)
|
||||
|
||||
|
||||
def cov(x, *, correction=1, fweights=None, aweights=None):
    r"""
    Estimate the covariance matrix of `x` by forwarding every argument to ``F.cov``.

    For details, please refer to :func:`mindspore.ops.cov`.
    """
    covariance = F.cov(x, correction=correction, fweights=fweights, aweights=aweights)
    return covariance
|
||||
|
||||
|
||||
def acos(x):
|
||||
r"""
|
||||
Computes arccosine of input tensors element-wise.
|
||||
|
|
|
@ -1232,6 +1232,13 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('cos')(self)
|
||||
|
||||
def cov(self, *, correction=1, fweights=None, aweights=None):
    r"""
    Estimate the covariance matrix of this tensor.

    For details, please refer to :func:`mindspore.ops.cov`.
    """
    self._init_check()
    # The implementation is looked up by name in the operator registry and
    # receives this tensor as its first positional argument.
    cov_op = tensor_operator_registry.get('cov')
    return cov_op(self, correction=correction, fweights=fweights, aweights=aweights)
|
||||
|
||||
def acosh(self):
|
||||
r"""
|
||||
Computes inverse hyperbolic cosine of the inputs element-wise.
|
||||
|
|
|
@ -159,6 +159,7 @@ from .math_func import (
|
|||
bincount,
|
||||
tensor_add,
|
||||
cosine_similarity,
|
||||
cov,
|
||||
add,
|
||||
addcdiv,
|
||||
addcmul,
|
||||
|
|
|
@ -1894,6 +1894,146 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-08):
|
|||
return output
|
||||
|
||||
|
||||
def _check_cov_weights(weights, weights_name, num_observations, valid_type, valid_type_name):
    """
    Validate a weights tensor supplied to cov.

    The checks run in this order: number of dimensions, dtype, element count.

    Args:
        weights: The weights tensor to validate.
        weights_name: Name of the argument, used in the error messages.
        num_observations: Expected number of elements in `weights`.
        valid_type: Collection of dtypes accepted for `weights`.
        valid_type_name: Human-readable description of `valid_type`, used in the error message.

    Returns:
        0 when every check passes.

    Raises:
        ValueError: If `weights` has more than one dimension.
        TypeError: If the dtype of `weights` is not in `valid_type`.
        ValueError: If the number of elements of `weights` differs from `num_observations`.
    """
    weights_rank = weights.ndim
    if weights_rank > 1:
        raise ValueError(
            f"For cov, the {weights_name} must have one or fewer dimensions, but got {weights_rank} dimensions.")

    weights_dtype = weights.dtype
    if weights_dtype not in valid_type:
        raise TypeError(
            f"For cov, the dtype of {weights_name} must be {valid_type_name} type, but got type {weights_dtype}")

    weights_numel = ops.numel(weights)
    if weights_numel != num_observations:
        raise ValueError(
            f"For cov, the numel of {weights_name} must equal the number of columns of input, "
            f"but got numel:{weights_numel}, number of columns of input:{num_observations}.")
    return 0
|
||||
|
||||
|
||||
def cov(x, *, correction=1, fweights=None, aweights=None):
    r"""
    Estimates the covariance matrix of the variables given by the input matrix,
    where rows are the variables and columns are the observations.

    A covariance matrix is a square matrix giving the covariance of each pair of variables. The diagonal contains
    the variance of each variable (covariance of a variable with itself). By definition, if `x` represents
    a single variable (Scalar or 1D) then its variance is returned.

    The unbiased sample covariance of the variables :math:`a` and :math:`b` is given by:

    .. math::
        \text{cov}(a,b) = \frac{\sum^{N}_{i = 1}(a_{i} - \bar{a})(b_{i} - \bar{b})}{N~-~1}

    where :math:`\bar{a}` and :math:`\bar{b}` are the simple means of :math:`a` and :math:`b` respectively.

    If `fweights` and/or `aweights` are provided, the unbiased weighted covariance
    is calculated, which is given by:

    .. math::
        \text{cov}_w(a,b) = \frac{\sum^{N}_{i = 1}w_i(a_{i} - \mu_a^*)(b_{i} - \mu_b^*)}{\sum^{N}_{i = 1}w_i~-~1}

    where :math:`w` denotes `fweights` or `aweights` based on whichever is provided, or
    :math:`w = fweights \times aweights` if both are provided, and
    :math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i} }{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable.

    Note:
        When `aweights` is given and `correction` is not 0, the divisor actually used is
        :math:`\sum w_i - correction \times \frac{\sum w_i \cdot aweights_i}{\sum w_i}`; otherwise it is
        :math:`\sum w_i - correction`. A divisor that is not positive is replaced by zero.

    .. warning::
        The values of `fweights` and `aweights` cannot be negative; the result for negative weights is undefined.

    Args:
        x (Tensor): A 2D matrix containing multiple variables and observations, or a
            Scalar or 1D vector representing a single variable.

    Keyword Args:
        correction (int, optional): Difference between the sample size and the sample degrees of freedom.
            Defaults to Bessel's correction, `correction = 1`, which returns the unbiased estimate,
            even if both `fweights` and `aweights` are specified. `correction = 0`
            will return the simple average. Default: 1.
        fweights (Tensor, optional): A Scalar or 1D tensor of observation vector frequencies representing the
            number of times each observation should be repeated. Its numel must equal the number of columns
            of `x`. Must have integer dtype. Ignored if `None`. Default: None.
        aweights (Tensor, optional): A Scalar or 1D tensor of observation vector weights.
            These relative weights are typically large for observations considered "important" and smaller for
            observations considered less "important". Its numel must equal the number of columns of `x`.
            Must have floating point dtype. Ignored if `None`. Default: None.

    Returns:
        Tensor, the covariance matrix of the variables.

    Raises:
        ValueError: If the number of dimensions of the input is greater than 2.
        ValueError: If the number of dimensions of `fweights` is greater than 1.
        ValueError: If the numel of `fweights` is not equal to the number of columns of the input.
        ValueError: If the numel of `aweights` is not equal to the number of columns of the input.
        ValueError: If the number of dimensions of `aweights` is greater than 1.
        TypeError: If the dtype of the input is bool.
        TypeError: If the dtype of `fweights` is not an integer type.
        TypeError: If the dtype of `aweights` is not a floating point type.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.ops as ops
        >>> x = ms.Tensor([[0, 2], [1, 1], [2, 0]]).T
        >>> print(x)
        [[0 1 2]
         [2 1 0]]
        >>> print(ops.cov(x))
        [[ 1. -1.]
         [-1.  1.]]
        >>> print(ops.cov(x, correction=0))
        [[ 0.6666667 -0.6666667]
         [-0.6666667  0.6666667]]
        >>> fw = ms.Tensor([5, 2, 4], dtype=ms.int64)
        >>> aw = ms.Tensor([0.4588, 0.9083, 0.7616], ms.float32)
        >>> print(ops.cov(x, fweights=fw, aweights=aw))
        [[ 0.81504613 -0.81504613]
         [-0.81504613  0.81504613]]
    """
    if x.ndim > 2:
        raise ValueError(f"For cov, the input must have two or fewer dimensions, but got {x.ndim} dimensions.")
    if x.dtype == mstype.bool_:
        raise TypeError("For cov, the input dtype can not be bool.")

    # A scalar or 1D input is a single variable: lay it out as one row.
    input_x = x if x.ndim >= 2 else x.view((1, -1))
    num_observations = input_x.shape[1]

    if fweights is not None:
        _check_cov_weights(fweights, "fweights", num_observations, mstype.int_type, "an integer")
    if aweights is not None:
        _check_cov_weights(aweights, "aweights", num_observations, mstype.float_type, "a floating point")

    # Combined per-observation weight; stays None when neither weight is given.
    if fweights is None:
        w = aweights
    elif aweights is None:
        w = fweights
    else:
        w = fweights * aweights

    # Total weight and the (weighted) mean of every variable.
    if w is None:
        w_sum = Tensor(num_observations, dtype=mstype.float32)
        avg = input_x.sum(1) / w_sum
    else:
        w_sum = w.sum().astype(mstype.float32)
        avg = (input_x * w).sum(1) / w_sum

    # Divisor of the estimate; `aweights` rescales the correction term.
    if w is None or aweights is None or correction == 0:
        norm_factor = w_sum - correction
    else:
        norm_factor = w_sum - correction * (w * aweights).sum() / w_sum

    # A non-positive divisor is replaced by zero.
    if norm_factor <= 0:
        norm_factor = ops.zeros_like(norm_factor)

    centered = input_x - avg.unsqueeze(1)
    weighted = centered if w is None else centered * w
    c = ops.mm(centered, weighted.T)
    norm_factor = norm_factor.astype(mstype.float32)
    return ops.true_divide(c, norm_factor).squeeze()
|
||||
|
||||
|
||||
def t(x):
|
||||
r"""
|
||||
Transposes a 2-D Tensor. 1-D Tensor are returned as it is.
|
||||
|
@ -9882,6 +10022,7 @@ __all__ = [
|
|||
'cholesky_inverse',
|
||||
'conj',
|
||||
'cosine_similarity',
|
||||
'cov',
|
||||
'cross',
|
||||
'einsum',
|
||||
'erfinv',
|
||||
|
|
|
@ -145,6 +145,7 @@ tensor_operator_registry.register('acos', acos)
|
|||
tensor_operator_registry.register('cos', cos)
|
||||
tensor_operator_registry.register('acosh', acosh)
|
||||
tensor_operator_registry.register('cosh', P.Cosh)
|
||||
tensor_operator_registry.register('cov', cov)
|
||||
tensor_operator_registry.register('asin', asin)
|
||||
tensor_operator_registry.register('sin', sin)
|
||||
tensor_operator_registry.register('sinc', sinc)
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Cell whose construct forwards its inputs to ops.cov."""

    def construct(self, x, correction=1, fweights=None, aweights=None):
        return ops.cov(x, correction=correction, fweights=fweights, aweights=aweights)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_cov_normal(mode):
    """
    Feature: cov
    Description: Verify the result of cov
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()
    x = ms.Tensor([[0, 2], [1, 1], [2, 0]]).T

    # Default arguments (correction=1, no weights).
    default_result = net(x)
    default_expected = np.array([[1., -1.],
                                 [-1., 1.]])
    assert np.allclose(default_result.asnumpy(), default_expected)

    # correction=0.
    uncorrected_result = net(x, correction=0)
    uncorrected_expected = np.array([[0.6666667, -0.6666667],
                                     [-0.6666667, 0.6666667]])
    assert np.allclose(uncorrected_result.asnumpy(), uncorrected_expected)

    # Both frequency and observation weights supplied.
    fw = ms.Tensor([5, 2, 4], dtype=ms.int64)
    aw = ms.Tensor([0.4588, 0.9083, 0.7616], ms.float32)
    weighted_result = net(x, fweights=fw, aweights=aw)
    weighted_expected = np.array([[0.81504613, -0.81504613],
                                  [-0.81504613, 0.81504613]])
    assert np.allclose(weighted_result.asnumpy(), weighted_expected)
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Cell whose construct calls the cov method of its input tensor."""

    def construct(self, x, correction=1, fweights=None, aweights=None):
        return x.cov(correction=correction, fweights=fweights, aweights=aweights)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_cov_normal(mode):
    """
    Feature: cov
    Description: Verify the result of cov
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()
    x = ms.Tensor([[0, 2], [1, 1], [2, 0]]).T

    # Default arguments (correction=1, no weights).
    default_result = net(x)
    default_expected = np.array([[1., -1.],
                                 [-1., 1.]])
    assert np.allclose(default_result.asnumpy(), default_expected)

    # correction=0.
    uncorrected_result = net(x, correction=0)
    uncorrected_expected = np.array([[0.6666667, -0.6666667],
                                     [-0.6666667, 0.6666667]])
    assert np.allclose(uncorrected_result.asnumpy(), uncorrected_expected)

    # Both frequency and observation weights supplied.
    fw = ms.Tensor([5, 2, 4], dtype=ms.int64)
    aw = ms.Tensor([0.4588, 0.9083, 0.7616], ms.float32)
    weighted_result = net(x, fweights=fw, aweights=aw)
    weighted_expected = np.array([[0.81504613, -0.81504613],
                                  [-0.81504613, 0.81504613]])
    assert np.allclose(weighted_result.asnumpy(), weighted_expected)
|
Loading…
Reference in New Issue