!48251 add api: ops.matrix_power

Merge pull request !48251 from GuoZhibin/add_api_ops_matrix_power
This commit is contained in:
i-robot 2023-02-06 12:12:30 +00:00 committed by Gitee
commit 60438fd061
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 117 additions and 0 deletions

View File

@ -381,6 +381,7 @@ Reduction函数
mindspore.ops.matrix_determinant mindspore.ops.matrix_determinant
mindspore.ops.matrix_diag mindspore.ops.matrix_diag
mindspore.ops.matrix_diag_part mindspore.ops.matrix_diag_part
mindspore.ops.matrix_power
mindspore.ops.matrix_set_diag mindspore.ops.matrix_set_diag
mindspore.ops.mm mindspore.ops.mm
mindspore.ops.mv mindspore.ops.mv

View File

@ -0,0 +1,24 @@
mindspore.ops.matrix_power
==========================
.. py:function:: mindspore.ops.matrix_power(x, n)
计算一个方阵的整数n次幂。
如果 :math:`n=0` 则返回一个单位矩阵其shape和输入 `x` 一致。
如果 :math:`n<0` 且输入矩阵 `x` 可逆,则返回输入矩阵 `x` 的逆矩阵的 :math:`-n` 次幂。
参数:
- **x** (Tensor) - 一个3-D Tensor。支持的数据类型为float16和float32。
shape为 :math:`(b, m, m)` 表示b个m-D的方阵。
- **n** (int) - 指数,必须是整数。
返回:
一个3-D Tensor`x` 的shape和数据类型均相同。
异常:
- **TypeError** - 如果 `n` 的数据类型不是整数。
- **TypeError** - 如果 `x` 的数据类型既不是float16又不是float32。
- **TypeError** - 如果 `x` 不是Tensor。
- **ValueError** - 如果 `x` 不是一个3-D Tensor。
- **ValueError** - 如果 `x` 的shape[1]和shape[2]不同。
- **ValueError** - 如果 `n` 为负数,但是输入 `x` 中存在奇异矩阵。

View File

@ -381,6 +381,7 @@ Linear Algebraic Functions
mindspore.ops.matrix_determinant mindspore.ops.matrix_determinant
mindspore.ops.matrix_diag mindspore.ops.matrix_diag
mindspore.ops.matrix_diag_part mindspore.ops.matrix_diag_part
mindspore.ops.matrix_power
mindspore.ops.matrix_set_diag mindspore.ops.matrix_set_diag
mindspore.ops.mm mindspore.ops.mm
mindspore.ops.mv mindspore.ops.mv

View File

@ -388,6 +388,7 @@ from .math_func import (
roll, roll,
orgqr, orgqr,
sum, sum,
matrix_power,
matrix_exp, matrix_exp,
logspace logspace
) )

View File

@ -10019,6 +10019,45 @@ def tanhshrink(x):
return x - tanh_op(x) return x - tanh_op(x)
def matrix_power(x, n):
"""
Raises a square matrix to the (integer) power `n` .
- When :math:`n=0` , returns the identity matrix, which has the same shape as `x` .
- When :math:`n<0` and `x` is invertible, returns the inverse of `x` to the power of :math:`-n` .
Args:
x (Tensor): A 3-D Tensor. Supported data types are float16 and float32.
The shape is :math:`(b, m, m)` , represents b m-D square matrices.
n (int): The exponent, a required int.
Returns:
A 3-D Tensor. Data type and shape are the same as `x` 's.
Raises:
TypeError: If the data type of `n` is not int.
TypeError: If the data type of `x` is neither float32 nor float16.
TypeError: If `x` is not a Tensor.
ValueError: If `x` is not a 3-D tensor.
ValueError: If shape[1] and shape[2] of `x` are not the same.
ValueError: If `n` is negative but got input `x` has singular matrices.
Supported Platforms:
``CPU``
Examples:
>>> x = Tensor([[[0, 1], [-1, 0]], [[1, 0], [0, -1]]], dtype=ms.float32)
>>> y = ops.matrix_power(x, 2)
>>> print(y)
[[[-1. 0.]
[-0. -1.]]
[[ 1. 0.]
[ 0. 1.]]]
"""
matrix_power_ops = _get_cache_prim(P.MatrixPower)(n=n)
return matrix_power_ops(x)
__all__ = [ __all__ = [
'addn', 'addn',
'absolute', 'absolute',
@ -10248,6 +10287,7 @@ __all__ = [
'roll', 'roll',
'sum', 'sum',
'matrix_exp', 'matrix_exp',
'matrix_power',
'orgqr', 'orgqr',
] ]
__all__.sort() __all__.sort()

View File

@ -0,0 +1,50 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import ops
import mindspore.nn as nn
class NetMatrixPower(nn.Cell):
def construct(self, x, n):
return ops.matrix_power(x, n)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_matrix_power(mode):
"""
Feature: matrix_power
Description: Verify the result of matrix_power
Expectation: success.
"""
ms.set_context(mode=mode)
arrs = [
np.random.rand(1, 2, 2).astype('float32'),
np.random.rand(2, 3, 3).astype('float32'),
np.random.rand(3, 4, 4).astype('float32'),
]
net_matrix_power = NetMatrixPower()
for arr in arrs:
for n in range(-2, 4):
expect_out = np.linalg.matrix_power(arr, n)
out = net_matrix_power(ms.Tensor(arr), n)
assert np.allclose(out.asnumpy(), expect_out, rtol=1e-4, atol=1e-4)