forked from mindspore-Ecosystem/mindspore
!48251 add api: ops.matrix_power
Merge pull request !48251 from GuoZhibin/add_api_ops_matrix_power
This commit is contained in:
commit
60438fd061
|
@ -381,6 +381,7 @@ Reduction函数
|
|||
mindspore.ops.matrix_determinant
|
||||
mindspore.ops.matrix_diag
|
||||
mindspore.ops.matrix_diag_part
|
||||
mindspore.ops.matrix_power
|
||||
mindspore.ops.matrix_set_diag
|
||||
mindspore.ops.mm
|
||||
mindspore.ops.mv
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
mindspore.ops.matrix_power
|
||||
==========================
|
||||
|
||||
.. py:function:: mindspore.ops.matrix_power(x, n)
|
||||
|
||||
计算一个方阵的(整数)n次幂。
|
||||
如果 :math:`n=0` ,则返回一个单位矩阵,其shape和输入 `x` 一致。
|
||||
如果 :math:`n<0` 且输入矩阵 `x` 可逆,则返回输入矩阵 `x` 的逆矩阵的 :math:`-n` 次幂。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 一个3-D Tensor。支持的数据类型为float16和float32。
|
||||
shape为 :math:`(b, m, m)` ,表示b个m-D的方阵。
|
||||
- **n** (int) - 指数,必须是整数。
|
||||
|
||||
返回:
|
||||
一个3-D Tensor,与 `x` 的shape和数据类型均相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `n` 的数据类型不是整数。
|
||||
- **TypeError** - 如果 `x` 的数据类型既不是float16,又不是float32。
|
||||
- **TypeError** - 如果 `x` 不是Tensor。
|
||||
- **ValueError** - 如果 `x` 不是一个3-D Tensor。
|
||||
- **ValueError** - 如果 `x` 的shape[1]和shape[2]不同。
|
||||
- **ValueError** - 如果 `n` 为负数,但是输入 `x` 中存在奇异矩阵。
|
|
@ -381,6 +381,7 @@ Linear Algebraic Functions
|
|||
mindspore.ops.matrix_determinant
|
||||
mindspore.ops.matrix_diag
|
||||
mindspore.ops.matrix_diag_part
|
||||
mindspore.ops.matrix_power
|
||||
mindspore.ops.matrix_set_diag
|
||||
mindspore.ops.mm
|
||||
mindspore.ops.mv
|
||||
|
|
|
@ -388,6 +388,7 @@ from .math_func import (
|
|||
roll,
|
||||
orgqr,
|
||||
sum,
|
||||
matrix_power,
|
||||
matrix_exp,
|
||||
logspace
|
||||
)
|
||||
|
|
|
@ -10019,6 +10019,45 @@ def tanhshrink(x):
|
|||
return x - tanh_op(x)
|
||||
|
||||
|
||||
def matrix_power(x, n):
|
||||
"""
|
||||
Raises a square matrix to the (integer) power `n` .
|
||||
|
||||
- When :math:`n=0` , returns the identity matrix, which has the same shape as `x` .
|
||||
- When :math:`n<0` and `x` is invertible, returns the inverse of `x` to the power of :math:`-n` .
|
||||
|
||||
Args:
|
||||
x (Tensor): A 3-D Tensor. Supported data types are float16 and float32.
|
||||
The shape is :math:`(b, m, m)` , represents b m-D square matrices.
|
||||
n (int): The exponent, a required int.
|
||||
|
||||
Returns:
|
||||
A 3-D Tensor. Data type and shape are the same as `x` 's.
|
||||
|
||||
Raises:
|
||||
TypeError: If the data type of `n` is not int.
|
||||
TypeError: If the data type of `x` is neither float32 nor float16.
|
||||
TypeError: If `x` is not a Tensor.
|
||||
ValueError: If `x` is not a 3-D tensor.
|
||||
ValueError: If shape[1] and shape[2] of `x` are not the same.
|
||||
ValueError: If `n` is negative but got input `x` has singular matrices.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([[[0, 1], [-1, 0]], [[1, 0], [0, -1]]], dtype=ms.float32)
|
||||
>>> y = ops.matrix_power(x, 2)
|
||||
>>> print(y)
|
||||
[[[-1. 0.]
|
||||
[-0. -1.]]
|
||||
[[ 1. 0.]
|
||||
[ 0. 1.]]]
|
||||
"""
|
||||
matrix_power_ops = _get_cache_prim(P.MatrixPower)(n=n)
|
||||
return matrix_power_ops(x)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'addn',
|
||||
'absolute',
|
||||
|
@ -10248,6 +10287,7 @@ __all__ = [
|
|||
'roll',
|
||||
'sum',
|
||||
'matrix_exp',
|
||||
'matrix_power',
|
||||
'orgqr',
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
|
||||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import ops
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class NetMatrixPower(nn.Cell):
|
||||
def construct(self, x, n):
|
||||
return ops.matrix_power(x, n)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_matrix_power(mode):
|
||||
"""
|
||||
Feature: matrix_power
|
||||
Description: Verify the result of matrix_power
|
||||
Expectation: success.
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
arrs = [
|
||||
np.random.rand(1, 2, 2).astype('float32'),
|
||||
np.random.rand(2, 3, 3).astype('float32'),
|
||||
np.random.rand(3, 4, 4).astype('float32'),
|
||||
]
|
||||
net_matrix_power = NetMatrixPower()
|
||||
|
||||
for arr in arrs:
|
||||
for n in range(-2, 4):
|
||||
expect_out = np.linalg.matrix_power(arr, n)
|
||||
out = net_matrix_power(ms.Tensor(arr), n)
|
||||
assert np.allclose(out.asnumpy(), expect_out, rtol=1e-4, atol=1e-4)
|
Loading…
Reference in New Issue