!47353 matrix_exp 文档补齐及对外暴露

Merge pull request !47353 from haozhang/matrix_exp
This commit is contained in:
i-robot 2022-12-30 09:00:26 +00:00 committed by Gitee
commit 950f302ea1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 32 additions and 5 deletions

View File

@ -155,11 +155,12 @@ mindspore.ops
mindspore.ops.cross
mindspore.ops.cumprod
mindspore.ops.erfinv
mindspore.ops.less_equal
mindspore.ops.igamma
mindspore.ops.igammac
mindspore.ops.is_floating_point
mindspore.ops.is_complex
mindspore.ops.is_floating_point
mindspore.ops.less_equal
mindspore.ops.matrix_exp
mindspore.ops.pinv
逐元素运算

View File

@ -0,0 +1,23 @@
mindspore.ops.matrix_exp
========================
.. py:function:: mindspore.ops.matrix_exp(x)
计算方阵的矩阵指数。支持batch维输入。
.. math::
matrix\_exp(x) = \sum_{k=0}^{\infty} \frac{1}{k !} x^{k} \in \mathbb{K}^{n \times n}
参数:
- **x** (Tensor) - 输入Tensorshape为 :math:`(*, n, n)` ,其中 `*` 表示0或更多的batch维。
支持数据类型float16、float32、float64、complex64、complex128。
返回:
Tensor其shape和数据类型均与 `x` 相同。
异常:
- **TypeError** - `x` 不为Tensor。
- **TypeError** - `x` 的dtype不属于以下类型float16、float32、float64、complex64、complex128。
- **ValueError** - `x` 的秩小于2。
- **ValueError** - `x` 的最后两维不相等。

View File

@ -156,10 +156,11 @@ Mathematical Functions
mindspore.ops.cross
mindspore.ops.cumprod
mindspore.ops.erfinv
mindspore.ops.less_equal
mindspore.ops.igamma
mindspore.ops.igammac
mindspore.ops.is_floating_point
mindspore.ops.less_equal
mindspore.ops.matrix_exp
mindspore.ops.pinv
Element-by-Element Operations

View File

@ -370,6 +370,7 @@ from .math_func import (
roll,
orgqr,
sum,
matrix_exp
)
from .nn_func import (
adaptive_avg_pool1d,

View File

@ -3080,7 +3080,7 @@ def matrix_exp(x):
Args:
x (Tensor): The shape of tensor is :math:`(*, n, n)` where * is zero or more batch dimensions.
Must be one of the following types: float64, float32, float16, complex64, complex128.
Must be one of the following types: float16, float32, float64, complex64, complex128.
Returns:
Tensor, has the same shape and dtype as the `x`.
@ -9837,6 +9837,7 @@ __all__ = [
'logical_xor',
'imag',
'roll',
'sum'
'sum',
'matrix_exp'
]
__all__.sort()