forked from mindspore-Ecosystem/mindspore
!48251 add api: ops.matrix_power
Merge pull request !48251 from GuoZhibin/add_api_ops_matrix_power
This commit is contained in:
commit
60438fd061
|
@ -381,6 +381,7 @@ Reduction函数
|
||||||
mindspore.ops.matrix_determinant
|
mindspore.ops.matrix_determinant
|
||||||
mindspore.ops.matrix_diag
|
mindspore.ops.matrix_diag
|
||||||
mindspore.ops.matrix_diag_part
|
mindspore.ops.matrix_diag_part
|
||||||
|
mindspore.ops.matrix_power
|
||||||
mindspore.ops.matrix_set_diag
|
mindspore.ops.matrix_set_diag
|
||||||
mindspore.ops.mm
|
mindspore.ops.mm
|
||||||
mindspore.ops.mv
|
mindspore.ops.mv
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
mindspore.ops.matrix_power
|
||||||
|
==========================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.matrix_power(x, n)
|
||||||
|
|
||||||
|
计算一个方阵的(整数)n次幂。
|
||||||
|
如果 :math:`n=0` ,则返回一个单位矩阵,其shape和输入 `x` 一致。
|
||||||
|
如果 :math:`n<0` 且输入矩阵 `x` 可逆,则返回输入矩阵 `x` 的逆矩阵的 :math:`-n` 次幂。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **x** (Tensor) - 一个3-D Tensor。支持的数据类型为float16和float32。
|
||||||
|
shape为 :math:`(b, m, m)` ,表示b个m-D的方阵。
|
||||||
|
- **n** (int) - 指数,必须是整数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
一个3-D Tensor,与 `x` 的shape和数据类型均相同。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - 如果 `n` 的数据类型不是整数。
|
||||||
|
- **TypeError** - 如果 `x` 的数据类型既不是float16,又不是float32。
|
||||||
|
- **TypeError** - 如果 `x` 不是Tensor。
|
||||||
|
- **ValueError** - 如果 `x` 不是一个3-D Tensor。
|
||||||
|
- **ValueError** - 如果 `x` 的shape[1]和shape[2]不同。
|
||||||
|
- **ValueError** - 如果 `n` 为负数,但是输入 `x` 中存在奇异矩阵。
|
|
@ -381,6 +381,7 @@ Linear Algebraic Functions
|
||||||
mindspore.ops.matrix_determinant
|
mindspore.ops.matrix_determinant
|
||||||
mindspore.ops.matrix_diag
|
mindspore.ops.matrix_diag
|
||||||
mindspore.ops.matrix_diag_part
|
mindspore.ops.matrix_diag_part
|
||||||
|
mindspore.ops.matrix_power
|
||||||
mindspore.ops.matrix_set_diag
|
mindspore.ops.matrix_set_diag
|
||||||
mindspore.ops.mm
|
mindspore.ops.mm
|
||||||
mindspore.ops.mv
|
mindspore.ops.mv
|
||||||
|
|
|
@ -388,6 +388,7 @@ from .math_func import (
|
||||||
roll,
|
roll,
|
||||||
orgqr,
|
orgqr,
|
||||||
sum,
|
sum,
|
||||||
|
matrix_power,
|
||||||
matrix_exp,
|
matrix_exp,
|
||||||
logspace
|
logspace
|
||||||
)
|
)
|
||||||
|
|
|
@ -10019,6 +10019,45 @@ def tanhshrink(x):
|
||||||
return x - tanh_op(x)
|
return x - tanh_op(x)
|
||||||
|
|
||||||
|
|
||||||
|
def matrix_power(x, n):
|
||||||
|
"""
|
||||||
|
Raises a square matrix to the (integer) power `n` .
|
||||||
|
|
||||||
|
- When :math:`n=0` , returns the identity matrix, which has the same shape as `x` .
|
||||||
|
- When :math:`n<0` and `x` is invertible, returns the inverse of `x` to the power of :math:`-n` .
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): A 3-D Tensor. Supported data types are float16 and float32.
|
||||||
|
The shape is :math:`(b, m, m)` , represents b m-D square matrices.
|
||||||
|
n (int): The exponent, a required int.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 3-D Tensor. Data type and shape are the same as `x` 's.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the data type of `n` is not int.
|
||||||
|
TypeError: If the data type of `x` is neither float32 nor float16.
|
||||||
|
TypeError: If `x` is not a Tensor.
|
||||||
|
ValueError: If `x` is not a 3-D tensor.
|
||||||
|
ValueError: If shape[1] and shape[2] of `x` are not the same.
|
||||||
|
ValueError: If `n` is negative but got input `x` has singular matrices.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x = Tensor([[[0, 1], [-1, 0]], [[1, 0], [0, -1]]], dtype=ms.float32)
|
||||||
|
>>> y = ops.matrix_power(x, 2)
|
||||||
|
>>> print(y)
|
||||||
|
[[[-1. 0.]
|
||||||
|
[-0. -1.]]
|
||||||
|
[[ 1. 0.]
|
||||||
|
[ 0. 1.]]]
|
||||||
|
"""
|
||||||
|
matrix_power_ops = _get_cache_prim(P.MatrixPower)(n=n)
|
||||||
|
return matrix_power_ops(x)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'addn',
|
'addn',
|
||||||
'absolute',
|
'absolute',
|
||||||
|
@ -10248,6 +10287,7 @@ __all__ = [
|
||||||
'roll',
|
'roll',
|
||||||
'sum',
|
'sum',
|
||||||
'matrix_exp',
|
'matrix_exp',
|
||||||
|
'matrix_power',
|
||||||
'orgqr',
|
'orgqr',
|
||||||
]
|
]
|
||||||
__all__.sort()
|
__all__.sort()
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import ops
|
||||||
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class NetMatrixPower(nn.Cell):
|
||||||
|
def construct(self, x, n):
|
||||||
|
return ops.matrix_power(x, n)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_matrix_power(mode):
|
||||||
|
"""
|
||||||
|
Feature: matrix_power
|
||||||
|
Description: Verify the result of matrix_power
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
arrs = [
|
||||||
|
np.random.rand(1, 2, 2).astype('float32'),
|
||||||
|
np.random.rand(2, 3, 3).astype('float32'),
|
||||||
|
np.random.rand(3, 4, 4).astype('float32'),
|
||||||
|
]
|
||||||
|
net_matrix_power = NetMatrixPower()
|
||||||
|
|
||||||
|
for arr in arrs:
|
||||||
|
for n in range(-2, 4):
|
||||||
|
expect_out = np.linalg.matrix_power(arr, n)
|
||||||
|
out = net_matrix_power(ms.Tensor(arr), n)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out, rtol=1e-4, atol=1e-4)
|
Loading…
Reference in New Issue