mindspore.ops.eye和pytorch.eye的参数m不对标问题修正。

1. 对参数m增加默认值None,返回Tensor的列数默认与n相等。
2. 对参数dtype增加默认值None,返回Tensor的数据类型默认为mindspore.float32。
This commit is contained in:
tangdezhi_123 2023-02-22 17:32:35 +08:00
parent 818d2ca40a
commit 5f22e08d05
2 changed files with 27 additions and 9 deletions

View File

@ -1,7 +1,7 @@
mindspore.ops.eye
==================
.. py:function:: mindspore.ops.eye(n, m, t)
.. py:function:: mindspore.ops.eye(n, m=None, dtype=None)
创建一个主对角线上元素为1其余元素为0的Tensor。
@ -10,11 +10,11 @@ mindspore.ops.eye
参数:
- **n** (int) - 指定返回Tensor的行数。仅支持常量值。
- **m** (int) - 指定返回Tensor的列数。仅支持常量值。
- **t** (mindspore.dtype) - 指定返回Tensor的数据类型。数据类型必须是 `bool_ <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_`number <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_
- **m** (int) - 指定返回Tensor的列数。仅支持常量值。默认值为None返回Tensor的列数默认与n相等。
- **dtype** (mindspore.dtype) - 指定返回Tensor的数据类型。数据类型必须是 `bool_ <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_`number <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_默认值为None返回Tensor的数据类型默认为mindspore.float32。
返回:
Tensor主对角线上为1其余的元素为0。它的shape由 `n``m` 指定。数据类型由 `t` 指定。
Tensor主对角线上为1其余的元素为0。它的shape由 `n``m` 指定。数据类型由 `dtype` 指定。
异常:
- **TypeError** - `m``n` 不是int。

View File

@ -20,6 +20,7 @@ import builtins
import operator
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr
@ -285,7 +286,7 @@ def cat(tensors, axis=0):
return _concat(tensors)
def eye(n, m, t):
def eye(n, m=None, dtype=None):
"""
Creates a tensor with ones on the diagonal and zeros in the rest.
@ -296,12 +297,13 @@ def eye(n, m, t):
Args:
n (int): The number of rows of returned tensor. Constant value only.
m (int): The number of columns of returned tensor. Constant value only.
t (mindspore.dtype): MindSpore's dtype, the data type of the returned tensor.
The data type can be Number.
Default: if None, the number of columns is as the same as n.
dtype (mindspore.dtype): MindSpore's dtype, the data type of the returned tensor.
The data type can be Number. Default: if None, the data type of the returned tensor is mindspore.float32.
Returns:
Tensor, a tensor with ones on the diagonal and the rest of elements are zero. The shape of `output` depends on
the user's Inputs `n` and `m`. And the data type depends on Inputs `t`.
the user's Inputs `n` and `m`. And the data type depends on Inputs `dtype`.
Raises:
TypeError: If `m` or `n` is not an int.
@ -322,8 +324,24 @@ def eye(n, m, t):
[[1. 0.]]
>>> print(output.dtype)
Float64
>>> output = ops.eye(2, t=mindspore.int32)
>>> print(output)
[[1 0]
[0 1]]
>>> print(output.dtype)
Int32
>>> output = ops.eye(2)
>>> print(output)
[[1. 0.]
[0. 1.]]
>>> print(output.dtype)
Float32
"""
return eye_(n, m, t)
if m is None:
m = n
if dtype is None:
dtype = ms.float32
return eye_(n, m, dtype)
def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None):